
Do not replace a privately customized Sampler in multi-card (distributed) runs

yhcc · 3 years ago · commit 7283bf27b2 · tags/v1.0.0alpha
7 changed files with 175 additions and 41 deletions
  1. fastNLP/core/controllers/evaluator.py  +5 -2
  2. fastNLP/core/controllers/trainer.py  +20 -7
  3. fastNLP/core/drivers/paddle_driver/fleet.py  +1 -2
  4. fastNLP/core/drivers/torch_driver/ddp.py  +7 -10
  5. fastNLP/core/drivers/torch_driver/single_device.py  +14 -12
  6. fastNLP/core/drivers/torch_driver/utils.py  +16 -3
  7. tests/core/drivers/torch_driver/test_ddp.py  +112 -5

fastNLP/core/controllers/evaluator.py  +5 -2

@@ -107,8 +107,11 @@ class Evaluator:
``dropout`` and ``batch normalization`` are turned off. Defaults to ``True``. If ``False``, ``fastNLP`` does not change the ``model``'s evaluate state at all. Whatever
the value is, ``fastNLP`` sets the ``model`` back to ``train`` mode after evaluation;
* *use_dist_sampler* --
-    True / False, whether to evaluate in a distributed way. The parameter only takes effect when ``driver`` is a distributed type. By default it follows whether the
-    ``driver`` supports distributed execution. If ``True``, the ``dataloader`` on each process automatically reads different data, and the union over all processes is the whole dataset;
+    whether, when running distributed evaluation, the ``Evaluator`` replaces the dataloader's ``sampler`` with a distributed
+    ``sampler`` whose property is that the data on different cards do not overlap and the data of all cards together form the whole dataset. If the sampler of the passed-in dataloader
+    is (a) the deep-learning framework's own default sampler or (b) a fastNLP Sampler, it is replaced with
+    :class:`~fastNLP.UnrepeatedSequentialSampler`; if this behaviour is not what you expect, set this parameter to ``False`` and control the data
+    each card may use yourself. For any other kind of sampler, fastNLP raises an error.
* *output_from_new_proc* -- same as the ``output_from_new_proc`` parameter of ``Trainer``;
* *progress_bar* -- same as the ``progress_bar`` parameter of ``Trainer``;
* *check_dataloader_legality* -- whether to check that the ``DataLoader`` is legal; defaults to ``True``.
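The opt-out path described for *use_dist_sampler* looks roughly like the sketch below. It is not part of this commit: it assumes the fastNLP 1.0 style ``Evaluator`` keyword arguments, placeholder ``model`` / ``eval_dataset`` / ``rank`` / ``world_size`` objects, and a hypothetical ``rank_local_loader`` helper that hands each card its own slice of the data.

# Sketch only: opting out of the automatic sampler replacement.
from torch.utils.data import DataLoader, Subset
from fastNLP import Evaluator, Accuracy

def rank_local_loader(dataset, rank, world_size, batch_size=32):
    # Hypothetical helper: give each card a disjoint slice of the dataset,
    # because with use_dist_sampler=False fastNLP will not do this for us.
    indices = list(range(rank, len(dataset), world_size))
    return DataLoader(Subset(dataset, indices), batch_size=batch_size)

evaluator = Evaluator(
    model=model,                                                    # placeholder model
    dataloaders=rank_local_loader(eval_dataset, rank, world_size),  # placeholder dataset / rank info
    metrics={'acc': Accuracy()},                                    # any fastNLP metric
    driver='torch',
    device=[0, 1],
    use_dist_sampler=False,   # keep the hand-made per-rank sampling untouched
)
evaluator.run()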


fastNLP/core/controllers/trainer.py  +20 -7

@@ -65,6 +65,13 @@ class Trainer(TrainerEventTrigger):
the model inside the ``Driver`` instance you passed in;

:param train_dataloader: the training data; note that it must be a single dataset, not a List or a Dict;

.. warning::

    When distributed training is used, ``fastNLP`` by default adjusts the ``Sampler`` of the ``dataloader`` so that, within one ``epoch``, the data
    used for training on different cards do not overlap. If you handle the sampler in some special way yourself, set the ``use_dist_sampler`` parameter to ``False``;
    in that case you must guarantee yourself that the data used on each card are different.

:param optimizers: the optimizer(s) needed for training; either a single optimizer instance or a List of optimizers;
:param device: this parameter specifies the device(s) actually used for training; note that it may be None only when you launch via `torch.distributed.launch/run`,
    in which case fastNLP does not move the model or the data between devices, though you can use the parameters `input_mapping` and `output_mapping` to move data between devices
@@ -93,9 +100,9 @@ class Trainer(TrainerEventTrigger):

.. warning::

-        Note that the parameter ``device`` may be ``None`` only when you launch ddp training through pytorch's, or another framework's, own parallel launch script!
+        Note that the parameter ``device`` may be ``None`` only when you launch ddp training through the training framework's own parallel launch script!

-        For example, when you use::
+        For example, in pytorch, when you use::

python -m torch.distributed.launch --nproc_per_node 2 train.py

@@ -290,16 +297,22 @@ class Trainer(TrainerEventTrigger):
        will the ``model_device`` of the driver instance be None;
     3. for paddle, this parameter has no effect;

-    * *use_dist_sampler* -- True / False, whether to use a distributed ``sampler``. With multiple cards, the distributed ``sampler`` automatically decides which samples each card reads, so that within one epoch
-      the samples of all cards add up to the whole dataset. By default this follows whether the driver is distributed.
-    * *evaluate_use_dist_sampler* -- True / False, whether the ``Evaluator``, when running distributed, replaces the dataloader's ``sampler`` with a distributed ``sampler``;
-      if this value is not passed, it follows the ``use_dist_sampler`` parameter;
+    * *use_dist_sampler* -- whether to use a distributed ``sampler``. With multiple cards, the distributed ``sampler`` automatically decides which samples each card reads, so that within one epoch
+      the samples of all cards add up to the whole dataset; at the same time, to guarantee that all cards hold the same number of samples, some cards may receive duplicated samples. For example, with
+      8-card training, only 9 samples, and batch_size 1, seven cards would have no sample left for the second batch, so samples have to be reused to pad that second
+      batch. If you do not want fastNLP to touch the dataloader's sampler in any special way, set this value to False; if you really do need distributed training, process the data of
+      train_dataloader outside the Trainer so that the samples on different cards do not overlap.
+    * *evaluate_use_dist_sampler* -- whether the ``Evaluator``, when running distributed, replaces the dataloader's ``sampler`` with the distributed
+      ``sampler`` used at evaluate time, whose property is that the data on different cards do not overlap and the data of all cards together form the whole dataset. If the sampler of the passed-in dataloader
+      is (a) the deep-learning framework's own default sampler or (b) a fastNLP Sampler, it is replaced with
+      :class:`~fastNLP.UnrepeatedSequentialSampler`; if this behaviour is not what you expect, set this parameter to ``False`` and control the data
+      each card may use yourself.
    * *output_from_new_proc* -- should be a string describing how the output streams of the other processes of a multi-process driver are handled; it should be one of
      ["all", "ignore", "only_error"]; when the value is none of these, it is taken as the name of a folder: the output streams of the other ranks are redirected into
      log files, and the log files are saved in the folder named by this value; defaults to "only_error";

      Note that this parameter only takes effect with a distributed ``driver``, e.g. ``TorchDDPDriver``;
-    * *progress_bar* -- how progress is displayed; currently supports [None, 'raw', 'rich', 'auto', 'tqdm'] or objects such as :class:`~.fastNLP.RichCallback`, :class:`~fastNLP.RawTextCallback`,
+    * *progress_bar* -- how progress is displayed; currently supports [None, 'raw', 'rich', 'auto', 'tqdm'] or objects such as :class:`~fastNLP.RichCallback`, :class:`~fastNLP.RawTextCallback`,
      defaults to auto; auto means that a :class:`~fastNLP.RichCallback` is used if the current terminal is detected as interactive, otherwise a :class:`~fastNLP.RawTextCallback` object. If the
      progress bar needs customizing, e.g. its print frequency, pass in a :class:`~fastNLP.RichCallback`, :class:`~fastNLP.RawTextCallback`, etc. object.
    * *train_input_mapping* -- same as input_mapping, but only used inside the ``Trainer``. Mutually exclusive with input_mapping.
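The duplication mentioned for *use_dist_sampler* (9 samples spread over 8 cards) can be reproduced with a tiny, self-contained sketch; the round-robin assignment below only illustrates the padding idea and is not fastNLP's actual sampler code.

# Why some cards see repeated samples: 9 samples, 8 cards, batch_size 1.
import math

num_samples, world_size = 9, 8
per_rank = math.ceil(num_samples / world_size)   # every card must yield 2 samples

for rank in range(world_size):
    # indices past the end of the dataset wrap around -> duplicated samples
    indices = [(rank + step * world_size) % num_samples for step in range(per_rank)]
    print(f"rank {rank}: {indices}")
# 8 cards * 2 samples = 16 indices in total, so 7 of the 9 samples appear twice.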


fastNLP/core/drivers/paddle_driver/fleet.py  +1 -2

@@ -422,8 +422,7 @@ class PaddleFleetDriver(PaddleDriver):
# trainer, evaluator
if dist is None:
if reproducible:
- raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our "
-                    "control.")
+ raise RuntimeError("It is not allowed to save checkpoint if the sampler is not allowed to be replaced.")
else:
args = self.get_dataloader_args(dataloader)
if isinstance(args.batch_sampler, ReproducibleBatchSampler):


fastNLP/core/drivers/torch_driver/ddp.py  +7 -10

@@ -140,6 +140,9 @@ if _NEED_IMPORT_TORCH:
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
- from torch.utils.data import BatchSampler
+ from torch.utils.data import RandomSampler as TorchRandomSampler
+ from torch.utils.data import SequentialSampler as TorchSequentialSampler
+ from torch.utils.data import BatchSampler as TorchBatchSampler

__all__ = [
'TorchDDPDriver'
@@ -159,6 +162,7 @@ from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, Unrepeated
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC
from fastNLP.core.log import logger
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object
from .utils import _check_dataloader_args_for_distributed


class TorchDDPDriver(TorchDriver):
@@ -535,8 +539,7 @@ class TorchDDPDriver(TorchDriver):
# trainer, evaluator
if dist is None:
if reproducible:
- raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
-                    "control.")
+ raise RuntimeError("It is not allowed to save checkpoint if the sampler is not allowed to be replaced.")
else:
args = self.get_dataloader_args(dataloader)
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
@@ -565,13 +568,7 @@ class TorchDDPDriver(TorchDriver):
)
return replace_sampler(dataloader, sampler)
else:
if type(args.batch_sampler) is not BatchSampler or (type(args.sampler) not in {torch.utils.data.RandomSampler,
torch.utils.data.SequentialSampler}):
raise TypeError("Using customized ``batch_sampler`` or ``sampler`` with 'DDP' may cause unseen problems, cause"
"we will substitute your dataloader's sampler into our ``fastNLP.RandomSampler``. You should make"
"your customized sampler being able to be used in distributed setting before you initialize ``Trainer`` by yourself,"
"and then set the parameter ``use_dist_sampler`` of ``Trainer`` to ``False``.")

_check_dataloader_args_for_distributed(args, controller='Trainer')
sampler = RandomSampler(
dataset=args.dataset,
shuffle=args.shuffle,
@@ -589,7 +586,7 @@ class TorchDDPDriver(TorchDriver):
if isinstance(args.sampler, ReproducibleSampler):
sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
elif not isinstance(args.sampler, UnrepeatedSampler):
# todo same as dist
_check_dataloader_args_for_distributed(args, controller='Evaluator')
sampler = UnrepeatedSequentialSampler(
dataset=args.dataset
)
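For a stock dataloader, the replacement performed above boils down to the two constructions below. This is a sketch mirroring the calls in the hunk, not code from the commit; ``dataset`` is just a stand-in object with a length.

# Sketch: what the Trainer / Evaluator branches above replace the sampler with.
from fastNLP.core.samplers import RandomSampler, UnrepeatedSequentialSampler

dataset = list(range(10))   # any sized dataset works as a stand-in

# Trainer side: every rank shuffles; the driver later restricts each rank to its
# own share of the indices, possibly padding so that all ranks stay in step.
train_sampler = RandomSampler(dataset=dataset, shuffle=True)

# Evaluator side: sequential and unrepeated; the union over all ranks is exactly
# the dataset, which is what the evaluate_use_dist_sampler docs promise.
eval_sampler = UnrepeatedSequentialSampler(dataset=dataset)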


fastNLP/core/drivers/torch_driver/single_device.py  +14 -12

@@ -8,6 +8,7 @@ if _NEED_IMPORT_TORCH:
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import RandomSampler as TorchRandomSampler
from torch.utils.data import SequentialSampler as TorchSequentialSampler
from torch.utils.data import BatchSampler as TorchBatchSampler

__all__ = [
'TorchSingleDriver'
@@ -123,19 +124,20 @@ class TorchSingleDriver(TorchDriver):
return replace_sampler(dataloader, sampler)

if reproducible:
- if isinstance(args.sampler, TorchRandomSampler):
-     if getattr(args.sampler, '_num_samples', None) is None \
-             and getattr(args.sampler, 'replacements', False) is False \
-             and getattr(args.sampler, 'generator', None) is None:
-         # If the sampler is already the plain random one with nothing customized, simply replace it.
-         sampler = RandomSampler(args.sampler.data_source, shuffle=True)
-         logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
-         return replace_sampler(dataloader, sampler)
- elif isinstance(args.sampler, TorchSequentialSampler):
-     # Needs to be replaced with a non-shuffling sampler.
-     sampler = RandomSampler(args.sampler.data_source, shuffle=False)
-     logger.debug("Replace torch SequentialSampler into fastNLP RandomSampler.")
-     return replace_sampler(dataloader, sampler)
+ if type(args.batch_sampler) is TorchBatchSampler:
+     if type(args.sampler) is TorchRandomSampler:
+         if getattr(args.sampler, '_num_samples', None) is None \
+                 and getattr(args.sampler, 'replacements', False) is False \
+                 and getattr(args.sampler, 'generator', None) is None:
+             # If the sampler is already the plain random one with nothing customized, simply replace it.
+             sampler = RandomSampler(args.sampler.data_source, shuffle=True)
+             logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
+             return replace_sampler(dataloader, sampler)
+     elif type(args.sampler) is TorchSequentialSampler:
+         # Needs to be replaced with a non-shuffling sampler.
+         sampler = RandomSampler(args.sampler.data_source, shuffle=False)
+         logger.debug("Replace torch SequentialSampler into fastNLP RandomSampler.")
+         return replace_sampler(dataloader, sampler)
batch_sampler = ReproduceBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
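The stricter condition introduced above only fires for a completely default torch setup; the following snippet uses plain torch objects, no fastNLP code, to show what the ``type(...) is`` checks accept and reject.

# Only a vanilla BatchSampler wrapping a vanilla RandomSampler/SequentialSampler
# is considered replaceable by the branch above.
from torch.utils.data import DataLoader, BatchSampler, RandomSampler, SequentialSampler

data = list(range(10))

default_dl = DataLoader(data, batch_size=4, shuffle=True)
print(type(default_dl.batch_sampler) is BatchSampler)   # True  -> eligible for replacement
print(type(default_dl.sampler) is RandomSampler)        # True  -> eligible for replacement

class MyBatchSampler(BatchSampler):
    """A do-nothing subclass: enough to make the type(...) checks fail."""

custom_dl = DataLoader(data, batch_sampler=MyBatchSampler(SequentialSampler(data), 4, False))
print(type(custom_dl.batch_sampler) is BatchSampler)    # False -> left untouched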


fastNLP/core/drivers/torch_driver/utils.py  +16 -3

@@ -22,7 +22,11 @@ if _NEED_IMPORT_TORCH:
import torch
# import torch.nn as nn
from torch.nn import Module
- from torch.utils.data import DataLoader, BatchSampler
+ from torch.utils.data import DataLoader
+ from torch.utils.data import RandomSampler as TorchRandomSampler
+ from torch.utils.data import SequentialSampler as TorchSequentialSampler
+ from torch.utils.data import BatchSampler as TorchBatchSampler

else:
from fastNLP.core.utils.dummy_class import DummyClass as Module

@@ -200,6 +204,7 @@ def replace_sampler(dataloader: "DataLoader", sampler):
non_default_params.add("dataset")

reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None})

batch_sampler = getattr(dataloader, "batch_sampler")
if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
@@ -277,5 +282,13 @@ def optimizer_state_to_device(state, device):
return new_state




def _check_dataloader_args_for_distributed(args, controller='Trainer'):
if type(args.batch_sampler) is not TorchBatchSampler and (type(args.sampler) not in {TorchRandomSampler,
TorchSequentialSampler}):
mode = 'training' if controller == 'Trainer' else 'evaluation'
substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler'
raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause "
f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into "
f"``{substitution}``. The customized sampler should be set up for distributed running "
f"before initializing ``{controller}`` , and then set the "
f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.")

tests/core/drivers/torch_driver/test_ddp.py  +112 -5

@@ -12,8 +12,9 @@ from fastNLP.core.samplers import (
)
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset
- from tests.helpers.utils import magic_argv_env_context
+ from tests.helpers.utils import magic_argv_env_context, recover_logger, Capturing
from fastNLP.envs.distributed import rank_zero_rm
from fastNLP import logger
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
@@ -936,12 +937,118 @@ def test_batch_sampler_dataloader(shuffle, batch_size, drop_last, reproducible=T
for i in range(1, len(data)-1):
flags.append(data[i]<data[i-1])
assert all(flags) is False
- datas = fastnlp_torch_all_gather(data)
- if drop_last:
-     assert len(set(datas[0] + datas[1])) == num_samples-_num_samples%batch_size*2
- else:
-     assert len(set(datas[0] + datas[1])) == num_samples
+ if dist.is_initialized():
+     datas = fastnlp_torch_all_gather(data)
+     if drop_last:
+         assert len(set(datas[0] + datas[1])) == num_samples-_num_samples%batch_size*2
+     else:
+         assert len(set(datas[0] + datas[1])) == num_samples
finally:
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()



@pytest.mark.torch
@magic_argv_env_context
@recover_logger
@pytest.mark.parametrize("inherit", ([True, False]))
def test_customized_batch_sampler_dataloader(inherit):
    try:
        logger.set_stdout('raw', level='info')
        # Check that set_dist_repro_dataloader behaves correctly when a customized batch_sampler is used.
        num_samples = 10
        dataset = TorchNormalXYDataset(num_samples)
        if inherit:
            class BatchSampler(torch.utils.data.BatchSampler):
                def __init__(self, dataset, batch_size):
                    self.dataset = dataset
                    self.batch_size = batch_size

                def __iter__(self):
                    indices = list(range(len(dataset)))
                    for i in range(len(self)):
                        start = i * self.batch_size
                        end = (i + 1) * self.batch_size
                        yield indices[start:end]   # yield one batch of indices at a time

                def __len__(self):
                    return (len(self.dataset)+self.batch_size-1)//self.batch_size
        else:
            class BatchSampler:
                def __init__(self, dataset, batch_size):
                    self.dataset = dataset
                    self.batch_size = batch_size

                def __iter__(self):
                    indices = list(range(len(dataset)))
                    for i in range(len(self)):
                        start = i * self.batch_size
                        end = (i + 1) * self.batch_size
                        yield indices[start:end]   # yield one batch of indices at a time

                def __len__(self):
                    return (len(self.dataset)+self.batch_size-1)//self.batch_size

        dl = prepare_torch_dataloader(dataset, batch_sampler=BatchSampler(dataset, batch_size=4))
        model = TorchNormalModel_Classification_1(10, 32)
        device = [torch.device(i) for i in [0, 1]]
        driver = TorchDDPDriver(model, parallel_device=device)
        driver.setup()
        # TODO: this call needs to raise
        with pytest.raises(TypeError):
            dl = driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=False)
    finally:
        if dist.is_initialized():
            dist.barrier()
            dist.destroy_process_group()


@pytest.mark.torch
@magic_argv_env_context
@recover_logger
@pytest.mark.parametrize("inherit", ([True, False]))
def test_customized_sampler_dataloader(inherit):
    try:
        logger.set_stdout('raw', level='info')
        # Check that set_dist_repro_dataloader behaves correctly when a customized sampler is used.
        num_samples = 10
        dataset = TorchNormalXYDataset(num_samples)
        if inherit:
            class Sampler(torch.utils.data.RandomSampler):
                def __init__(self, dataset, batch_size):
                    self.dataset = dataset
                    self.batch_size = batch_size

                def __iter__(self):
                    indices = list(range(len(dataset)))
                    return iter(indices)

                def __len__(self):
                    return len(self.dataset)
        else:
            class Sampler:
                def __init__(self, dataset, batch_size):
                    self.dataset = dataset
                    self.batch_size = batch_size

                def __iter__(self):
                    indices = list(range(len(dataset)))
                    return iter(indices)

                def __len__(self):
                    return len(self.dataset)

        dl = prepare_torch_dataloader(dataset, sampler=Sampler(dataset, batch_size=4))
        model = TorchNormalModel_Classification_1(10, 32)
        device = [torch.device(i) for i in [0, 1]]
        driver = TorchDDPDriver(model, parallel_device=device)
        driver.setup()
        # TODO: this call needs to raise
        with pytest.raises(TypeError):
            dl = driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=False)
    finally:
        if dist.is_initialized():
            dist.barrier()
            dist.destroy_process_group()
