@@ -107,8 +107,11 @@ class Evaluator:
``dropout`` and ``batch normalization`` will be turned off. Defaults to ``True``. If ``False``, ``fastNLP`` does not touch the ``model``'s evaluate state at all. Regardless of
this value, ``fastNLP`` sets the ``model`` back to ``train`` mode after evaluation;
* *use_dist_sampler* --
True / False, whether to evaluate in a distributed fashion. This parameter only takes effect when the ``driver`` is a distributed type. Defaults to whether the ``driver``
supports distributed execution. If ``True``, the ``dataloader`` of each process automatically uses different data, and the union of the data of all processes is the whole dataset;
whether, when running distributed, the ``Evaluator`` guarantees that the dataloader's ``sampler`` is replaced with a
distributed ``sampler``, whose property is that the data on the individual cards does not overlap and the data of all cards together forms the whole dataset. If the sampler of the passed-in dataloader
is (a) the deep learning framework's own default sampler or (b) a fastNLP Sampler, it will be replaced with
:class:`~fastNLP.UnrepeatedSequentialSampler`; if this behavior is not what you expect, set this parameter to ``False`` and control the data each card
may use yourself. If the sampler is of neither of these two kinds, fastNLP raises an error.
* *output_from_new_proc* -- equivalent to the ``output_from_new_proc`` parameter of ``Trainer``;
* *progress_bar* -- equivalent to the ``progress_bar`` parameter of ``Trainer``;
* *check_dataloader_legality* -- whether to check that the ``DataLoader`` is legal; defaults to ``True``.
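For illustration only (not part of the diff; the model, metric and dataloader names are placeholders), a minimal sketch of opting out of the automatic sampler replacement described above and feeding each rank its own data::

    from fastNLP import Evaluator

    # each rank builds a dataloader over its own, non-overlapping slice of the data
    evaluator = Evaluator(model=model, dataloaders=rank_local_dataloader,
                          metrics={'acc': my_metric}, driver='torch',
                          use_dist_sampler=False)
    evaluator.run()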
@@ -65,6 +65,13 @@ class Trainer(TrainerEventTrigger):
the model inside the ``Driver`` instance you passed in;
:param train_dataloader: the dataloader used for training; note that it must be a single dataloader, not a List or a Dict;
.. warning::
When training in a distributed setting, ``fastNLP`` by default adjusts the ``Sampler`` of the ``dataloader`` so that, within one ``epoch``, the data used
for training on the different cards does not overlap. If you handle the sampler yourself, set the ``use_dist_sampler`` parameter to ``False``; in that case it is up to
you to guarantee that the data used on each card is different.
:param optimizers: the optimizers needed for training; either a single optimizer instance or a List of several optimizers;
:param device: this parameter specifies the device(s) used for training; note that it may be None only when you launch the job via `torch.distributed.launch/run`,
in which case fastNLP will not move the model or the data between devices, but you can use the parameters `input_mapping` and `output_mapping` to move data between
@@ -93,9 +100,9 @@ class Trainer(TrainerEventTrigger):
.. warning::
Note that the parameter ``device`` may be ``None`` only when you launch ddp training through pytorch's or another training framework's own parallel launch script!
Note that the parameter ``device`` may be ``None`` only when you launch ddp training through the training framework's own parallel launch script!
For example, when you use::
For example, in pytorch, when you use::
python -m torch.distributed.launch --nproc_per_node 2 train.py
@@ -290,16 +297,22 @@ class Trainer(TrainerEventTrigger):
only then will the driver instance's ``model_device`` be None;
3. for paddle, this parameter has no effect;
* *use_dist_sampler* -- True / False, whether to use a distributed ``sampler``. With multiple cards, the distributed ``sampler`` automatically decides which samples each card reads, so that within one epoch
the samples of all cards add up to the samples of the whole dataset. Defaults to a value derived from whether the driver is distributed.
* *evaluate_use_dist_sampler* -- True / False, whether the ``Evaluator`` should replace the dataloader's ``sampler`` with a distributed ``sampler`` when running distributed;
if this value is not passed, it follows the ``use_dist_sampler`` parameter;
* *use_dist_sampler* -- whether to use a distributed ``sampler``. With multiple cards, the distributed ``sampler`` automatically decides which samples each card reads, so that within one epoch
the samples of all cards add up to the samples of the whole dataset. To guarantee that every card holds the same number of samples, some cards may receive duplicated samples; for example,
with 8-card training and only 9 samples at a batch_size of 1, seven cards would have no sample left for the second batch, so samples have to be reused to pad the second
batch. If you do not want fastNLP to touch the dataloader's sampler, set this value to False; if you still need distributed training in that case, process the data of the
train_dataloader outside the Trainer so that the samples on the different cards do not overlap.
* *evaluate_use_dist_sampler* -- whether, when running distributed, the ``Evaluator`` guarantees that the dataloader's ``sampler`` is replaced with
the distributed ``sampler`` used at evaluate time, whose property is that the data on the individual cards does not overlap and the data of all cards together forms the whole dataset. If the sampler of the passed-in dataloader
is (a) the deep learning framework's own default sampler or (b) a fastNLP Sampler, it will be replaced with
:class:`~fastNLP.UnrepeatedSequentialSampler`; if this behavior is not what you expect, set this parameter to ``False`` and control the data each card
may use yourself.
* *output_from_new_proc* -- should be a string describing how the output streams of the other processes of a multi-process driver are handled; its value should be one of
["all", "ignore", "only_error"]; when the value is none of these, it is taken as the name of a folder, the output streams of the other ranks are redirected into
log files, and the log files are stored in the folder named by this parameter; defaults to "only_error";
note that this parameter only takes effect when a distributed ``driver`` such as ``TorchDDPDriver`` is used;
* *progress_bar* -- how progress is displayed; currently one of [None, 'raw', 'rich', 'auto', 'tqdm'] or an object such as :class:`~.fastNLP.RichCallback` or :class:`~fastNLP.RawTextCallback`,
* *progress_bar* -- how progress is displayed; currently one of [None, 'raw', 'rich', 'auto', 'tqdm'] or an object such as :class:`~fastNLP.RichCallback` or :class:`~fastNLP.RawTextCallback`,
defaulting to auto; auto means that :class:`~fastNLP.RichCallback` is used if the current terminal is detected to be interactive and :class:`~fastNLP.RawTextCallback` otherwise. If
you need to customize the progress bar, for example its refresh frequency, pass in a :class:`~fastNLP.RichCallback`, :class:`~fastNLP.RawTextCallback`, etc. object.
* *train_input_mapping* -- same as input_mapping, but only used inside the ``Trainer``; mutually exclusive with input_mapping.
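As a hedged illustration (not part of the diff; the model, dataloader and optimizer names are placeholders), a sketch of a ``Trainer`` that opts out of the automatic sampler handling documented above and relies on dataloaders whose samplers already split the data per rank::

    from fastNLP import Trainer

    trainer = Trainer(model=model, driver='torch', device=[0, 1],
                      train_dataloader=rank_aware_train_dataloader,
                      optimizers=optimizer,
                      use_dist_sampler=False,            # leave the training sampler untouched
                      evaluate_use_dist_sampler=False,   # and the evaluation sampler as well
                      output_from_new_proc='only_error',
                      progress_bar='rich')
    trainer.run()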
@@ -422,8 +422,7 @@ class PaddleFleetDriver(PaddleDriver):
# trainer, evaluator
if dist is None:
if reproducible:
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our "
"control.")
raise RuntimeError("It is not allowed to save checkpoint if the sampler is not allowed to be replaced.")
else:
args = self.get_dataloader_args(dataloader)
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
@@ -140,6 +140,9 @@ if _NEED_IMPORT_TORCH:
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import BatchSampler
from torch.utils.data import RandomSampler as TorchRandomSampler
from torch.utils.data import SequentialSampler as TorchSequentialSampler
from torch.utils.data import BatchSampler as TorchBatchSampler
__all__ = [
'TorchDDPDriver'
@@ -159,6 +162,7 @@ from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, Unrepeated
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC
from fastNLP.core.log import logger
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object
from .utils import _check_dataloader_args_for_distributed
class TorchDDPDriver(TorchDriver):
@@ -535,8 +539,7 @@ class TorchDDPDriver(TorchDriver):
# trainer, evaluator
if dist is None:
if reproducible:
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
"control.")
raise RuntimeError("It is not allowed to save checkpoint if the sampler is not allowed to be replaced.")
else:
args = self.get_dataloader_args(dataloader)
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
@@ -565,13 +568,7 @@ class TorchDDPDriver(TorchDriver):
)
return replace_sampler(dataloader, sampler)
else:
if type(args.batch_sampler) is not BatchSampler or (type(args.sampler) not in {torch.utils.data.RandomSampler,
torch.utils.data.SequentialSampler}):
raise TypeError("Using customized ``batch_sampler`` or ``sampler`` with 'DDP' may cause unseen problems, cause"
"we will substitute your dataloader's sampler into our ``fastNLP.RandomSampler``. You should make"
"your customized sampler being able to be used in distributed setting before you initialize ``Trainer`` by yourself,"
"and then set the parameter ``use_dist_sampler`` of ``Trainer`` to ``False``.")
_check_dataloader_args_for_distributed(args, controller='Trainer')
sampler = RandomSampler(
dataset=args.dataset,
shuffle=args.shuffle,
@@ -589,7 +586,7 @@ class TorchDDPDriver(TorchDriver):
if isinstance(args.sampler, ReproducibleSampler):
sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
elif not isinstance(args.sampler, UnrepeatedSampler):
# todo same as dist
_check_dataloader_args_for_distributed(args, controller='Evaluator')
sampler = UnrepeatedSequentialSampler(
dataset=args.dataset
)
@@ -8,6 +8,7 @@ if _NEED_IMPORT_TORCH:
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import RandomSampler as TorchRandomSampler
from torch.utils.data import SequentialSampler as TorchSequentialSampler
from torch.utils.data import BatchSampler as TorchBatchSampler
__all__ = [
'TorchSingleDriver'
@@ -123,19 +124,20 @@ class TorchSingleDriver(TorchDriver):
return replace_sampler(dataloader, sampler)
if reproducible:
if isinstance(args.sampler, TorchRandomSampler):
if getattr(args.sampler, '_num_samples', None) is None \
and getattr(args.sampler, 'replacements', False) is False \
and getattr(args.sampler, 'generator', None) is None:
# if it is already random and not customized, just replace it
sampler = RandomSampler(args.sampler.data_source, shuffle=True)
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
if type(args.batch_sampler) is TorchBatchSampler:
if type(args.sampler) is TorchRandomSampler:
if getattr(args.sampler, '_num_samples', None) is None \
and getattr(args.sampler, 'replacements', False) is False \
and getattr(args.sampler, 'generator', None) is None:
# if it is already random and not customized, just replace it
sampler = RandomSampler(args.sampler.data_source, shuffle=True)
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
elif type(args.sampler) is TorchSequentialSampler:
# needs to be replaced with a non-shuffling one
sampler = RandomSampler(args.sampler.data_source, shuffle=False)
logger.debug("Replace torch SequentialSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
elif isinstance(args.sampler, TorchSequentialSampler):
# needs to be replaced with a non-shuffling one
sampler = RandomSampler(args.sampler.data_source, shuffle=False)
logger.debug("Replace torch SequentialSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
batch_sampler = ReproduceBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
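A hedged aside (not part of the diff) on why the ``type(...) is TorchBatchSampler`` / ``type(...) is TorchRandomSampler`` guards above match exactly a plain, uncustomized DataLoader::

    import torch
    from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler, TensorDataset

    ds = TensorDataset(torch.arange(10).float())
    shuffled = DataLoader(ds, batch_size=4, shuffle=True)
    plain = DataLoader(ds, batch_size=4)

    # a DataLoader built only from batch_size/shuffle uses the stock torch classes,
    # so the branches above fire and fastNLP can safely swap in its own RandomSampler
    assert type(shuffled.batch_sampler) is BatchSampler and type(shuffled.sampler) is RandomSampler
    assert type(plain.batch_sampler) is BatchSampler and type(plain.sampler) is SequentialSampler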
@@ -22,7 +22,11 @@ if _NEED_IMPORT_TORCH:
import torch
# import torch.nn as nn
from torch.nn import Module
from torch.utils.data import DataLoader, BatchSampler
from torch.utils.data import DataLoader
from torch.utils.data import RandomSampler as TorchRandomSampler
from torch.utils.data import SequentialSampler as TorchSequentialSampler
from torch.utils.data import BatchSampler as TorchBatchSampler
else:
from fastNLP.core.utils.dummy_class import DummyClass as Module
@@ -200,6 +204,7 @@ def replace_sampler(dataloader: "DataLoader", sampler):
non_default_params.add("dataset")
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None})
batch_sampler = getattr(dataloader, "batch_sampler")
if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
@@ -277,5 +282,13 @@ def optimizer_state_to_device(state, device):
return new_state
def _check_dataloader_args_for_distributed(args, controller='Trainer'):
if type(args.batch_sampler) is not TorchBatchSampler or (type(args.sampler) not in {TorchRandomSampler,
TorchSequentialSampler}):
mode = 'training' if controller == 'Trainer' else 'evaluation'
substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler'
raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause "
f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into "
f"``{substitution}``. The customized sampler should be set up for distributed running "
f"before initializing ``{controller}``, and then set the "
f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.")
@@ -12,8 +12,9 @@ from fastNLP.core.samplers import (
)
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset
from tests.helpers.utils import magic_argv_env_context
from tests.helpers.utils import magic_argv_env_context, recover_logger, Capturing
from fastNLP.envs.distributed import rank_zero_rm
from fastNLP import logger
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
@@ -936,12 +937,118 @@ def test_batch_sampler_dataloader(shuffle, batch_size, drop_last, reproducible=T
for i in range(1, len(data)-1):
flags.append(data[i]<data[i-1])
assert all(flags) is False
datas = fastnlp_torch_all_gather(data)
if drop_last:
assert len(set(datas[0] + datas[1])) == num_samples-_num_samples%batch_size*2
if dist.is_initialized():
datas = fastnlp_torch_all_gather(data)
if drop_last:
assert len(set(datas[0] + datas[1])) == num_samples-_num_samples%batch_size*2
else:
assert len(set(datas[0] + datas[1])) == num_samples
finally:
if dist.is_initialized():
dist.barrier()
dist.destroy_process_group()
@pytest.mark.torch
@magic_argv_env_context
@recover_logger
@pytest.mark.parametrize("inherit", ([True, False]))
def test_customized_batch_sampler_dataloader(inherit):
try:
logger.set_stdout('raw', level='info')
# verify that set_dist_repro_dataloader behaves correctly when the batch_sampler is customized
num_samples = 10
dataset = TorchNormalXYDataset(num_samples)
if inherit:
class BatchSampler(torch.utils.data.BatchSampler):
def __init__(self, dataset, batch_size):
self.dataset = dataset
self.batch_size = batch_size
def __iter__(self):
indices = list(range(len(dataset)))
for i in range(len(self)):
start = i * self.batch_size
end = (i + 1) * self.batch_size
yield indices[start:end]
def __len__(self):
return (len(self.dataset)+self.batch_size-1)//self.batch_size
else:
class BatchSampler:
def __init__(self, dataset, batch_size):
self.dataset = dataset
self.batch_size = batch_size
def __iter__(self):
indices = list(range(len(dataset)))
for i in range(len(self)):
start = i * self.batch_size
end = (i + 1) * self.batch_size
yield indices[start:end]
def __len__(self):
return (len(self.dataset)+self.batch_size-1)//self.batch_size
dl = prepare_torch_dataloader(dataset, batch_sampler=BatchSampler(dataset, batch_size=4))
model = TorchNormalModel_Classification_1(10, 32)
device = [torch.device(i) for i in [0, 1]]
driver = TorchDDPDriver(model, parallel_device=device)
driver.setup()
# this call is expected to raise
with pytest.raises(TypeError):
dl = driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=False)
finally:
if dist.is_initialized():
dist.barrier()
dist.destroy_process_group()
@pytest.mark.torch
@magic_argv_env_context
@recover_logger
@pytest.mark.parametrize("inherit", ([True, False]))
def test_customized_sampler_dataloader(inherit):
try:
logger.set_stdout('raw', level='info')
# verify that set_dist_repro_dataloader behaves correctly when the sampler is customized
num_samples = 10
dataset = TorchNormalXYDataset(num_samples)
if inherit:
class Sampler(torch.utils.data.RandomSampler):
def __init__(self, dataset, batch_size):
self.dataset = dataset
self.batch_size = batch_size
def __iter__(self):
indices = list(range(len(dataset)))
return iter(indices)
def __len__(self):
return len(self.dataset)
else:
class Sampler:
def __init__(self, dataset, batch_size):
self.dataset = dataset
self.batch_size = batch_size
def __iter__(self):
indices = list(range(len(dataset)))
return iter(indices)
def __len__(self):
return len(self.dataset)
dl = prepare_torch_dataloader(dataset, sampler=Sampler(dataset, batch_size=4))
model = TorchNormalModel_Classification_1(10, 32)
device = [torch.device(i) for i in [0, 1]]
driver = TorchDDPDriver(model, parallel_device=device)
driver.setup()
# this call is expected to raise
with pytest.raises(TypeError):
dl = driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=False)
finally:
if dist.is_initialized():
dist.barrier()
dist.destroy_process_group()
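For completeness, a hedged sketch (not part of the diff or the test suite) of the kind of rank-aware sampler the raised TypeError points users towards: one that already splits the data across cards, so the dataloader can be used with ``use_dist_sampler=False``::

    import torch.distributed as dist
    from torch.utils.data import Sampler

    class RankShardedSampler(Sampler):
        """Each rank iterates over its own disjoint slice of the dataset."""
        def __init__(self, dataset):
            self.dataset = dataset
            self.rank = dist.get_rank() if dist.is_initialized() else 0
            self.world_size = dist.get_world_size() if dist.is_initialized() else 1

        def __iter__(self):
            # round-robin split: rank r gets indices r, r + world_size, r + 2 * world_size, ...
            return iter(range(self.rank, len(self.dataset), self.world_size))

        def __len__(self):
            return len(range(self.rank, len(self.dataset), self.world_size))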