@@ -107,7 +107,7 @@ class Evaluator:
``dropout`` and ``batch normalization`` will be switched off. Defaults to ``True``. If ``False``, ``fastNLP`` makes no changes to the ``evaluate`` state of the ``model``. Regardless of
this value, ``fastNLP`` will set the ``model`` back to ``train`` mode after evaluation;
* *use_dist_sampler* --
Whether to evaluate in a distributed fashion. This parameter only takes effect when the ``driver`` is a distributed one. The default is chosen according to whether the ``driver`` supports
True / False, whether to evaluate in a distributed fashion. This parameter only takes effect when the ``driver`` is a distributed one. The default is chosen according to whether the ``driver`` supports
distributed execution. If ``True``, the ``dataloader`` on each process automatically uses different data, and the union of the data across all processes is the whole dataset;
* *output_from_new_proc* -- equivalent to the ``output_from_new_proc`` parameter of ``Trainer``;
* *progress_bar* -- equivalent to the ``progress_bar`` parameter of ``Trainer``;
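As a reading aid for the ``Evaluator`` keyword arguments documented in this hunk, here is a minimal sketch of how they are typically passed in; the ``model``, ``dataloader`` and ``metric`` objects are placeholders assumed to exist, and the concrete values are illustrative only:

```python
# Minimal sketch (not part of this diff): passing the kwargs documented above
# through to Evaluator. `model`, `dataloader` and `metric` are placeholders.
from fastNLP import Evaluator

evaluator = Evaluator(
    model=model,
    dataloaders=dataloader,            # a plain, non-distributed DataLoader
    metrics={"acc": metric},
    driver="torch",
    device=[0, 1],                     # two cards -> a distributed driver
    use_dist_sampler=True,             # each process evaluates a different shard
    output_from_new_proc="only_error", # same semantics as in Trainer
    progress_bar="rich",               # same semantics as in Trainer
)
results = evaluator.run()
```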
@@ -290,9 +290,9 @@ class Trainer(TrainerEventTrigger):
will the ``model_device`` of the driver instance be None;
3. For paddle, this parameter has no effect;
* *use_dist_sampler* -- whether to use a distributed ``sampler``. With multiple cards, the distributed ``sampler`` automatically decides which samples each card reads, so that within one epoch
* *use_dist_sampler* -- True / False, whether to use a distributed ``sampler``. With multiple cards, the distributed ``sampler`` automatically decides which samples each card reads, so that within one epoch
the samples of all cards together add up to exactly one full pass over the dataset. By default this is set according to whether the driver is distributed.
* *evaluate_use_dist_sampler* -- whether the dataloader's ``sampler`` is replaced with a distributed ``sampler`` when the ``Evaluator`` runs in a distributed setting;
* *evaluate_use_dist_sampler* -- True / False, whether the dataloader's ``sampler`` is replaced with a distributed ``sampler`` when the ``Evaluator`` runs in a distributed setting;
if this value is not passed in, it follows the ``use_dist_sampler`` parameter;
* *output_from_new_proc* -- should be a string describing how the output streams of the other processes in a multi-process driver are handled; its value should be one of
["all", "ignore", "only_error"]; when the value is none of the above, it should be the name of a folder, and the output streams of the other ranks will be redirected to
@@ -565,6 +565,13 @@ class TorchDDPDriver(TorchDriver):
)
return replace_sampler(dataloader, sampler)
else:
if type(args.batch_sampler) is not BatchSampler or (type(args.sampler) not in {torch.utils.data.RandomSampler,
torch.utils.data.SequentialSampler}):
raise TypeError("Using a customized ``batch_sampler`` or ``sampler`` with 'DDP' may cause unforeseen problems, because "
"we will substitute your dataloader's sampler with our ``fastNLP.RandomSampler``. You should make "
"your customized sampler usable in a distributed setting yourself before you initialize ``Trainer``, "
"and then set the parameter ``use_dist_sampler`` of ``Trainer`` to ``False``.")
sampler = RandomSampler(
dataset=args.dataset,
shuffle=args.shuffle,
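The ``TypeError`` added in this hunk points the user toward a workaround; a sketch of what that looks like on the user side follows. ``MyDistributedSampler`` is a hypothetical class standing in for a sampler the user has already made distributed-aware:

```python
# Hypothetical user-side workaround suggested by the TypeError above:
# keep a custom sampler that already shards data per rank, and tell fastNLP
# not to substitute it. `MyDistributedSampler` is a placeholder, not a real class.
from fastNLP import Trainer
from torch.utils.data import DataLoader

sampler = MyDistributedSampler(train_dataset)   # handles rank/world_size itself
train_dataloader = DataLoader(train_dataset, batch_size=32, sampler=sampler)

trainer = Trainer(
    model=model,
    driver="torch",
    device=[0, 1],
    optimizers=optimizer,
    train_dataloader=train_dataloader,
    use_dist_sampler=False,   # fastNLP leaves the custom sampler untouched
)
```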
@@ -582,6 +589,7 @@ class TorchDDPDriver(TorchDriver):
if isinstance(args.sampler, ReproducibleSampler):
sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
elif not isinstance(args.sampler, UnrepeatedSampler):
# todo same as dist
sampler = UnrepeatedSequentialSampler(
dataset=args.dataset
)
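For context on this evaluate-time branch: the ``Unrepeated*`` samplers guarantee that no sample is duplicated across ranks, so metrics are computed over the dataset exactly once. A toy illustration of that kind of index split (my own sketch, not fastNLP's actual ``UnrepeatedSequentialSampler``):

```python
# Toy illustration of "unrepeated" distributed sampling: each rank gets a
# disjoint slice of indices with no padding or repetition, so the union over
# all ranks is exactly the dataset (ranks may receive different counts).
def unrepeated_indices(num_samples: int, rank: int, world_size: int) -> list:
    return list(range(rank, num_samples, world_size))

assert unrepeated_indices(10, 0, 4) == [0, 4, 8]
assert unrepeated_indices(10, 3, 4) == [3, 7]   # one fewer sample, but no duplicates
```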
@@ -14,7 +14,7 @@ from fastNLP.envs import (
FASTNLP_BACKEND_LAUNCH,
FASTNLP_GLOBAL_SEED,
)
from fastNLP.core.samplers import re_instantiate_sampler
from fastNLP.core.samplers import re_instantiate_sampler, ReproducibleBatchSampler
from fastNLP.core.utils import auto_param_call
from fastNLP.core.log import logger
@@ -23,7 +23,6 @@ if _NEED_IMPORT_TORCH:
# import torch.nn as nn
from torch.nn import Module
from torch.utils.data import DataLoader, BatchSampler
from torch.utils.data.sampler import Sampler
else:
from fastNLP.core.utils.dummy_class import DummyClass as Module
@@ -201,7 +200,10 @@ def replace_sampler(dataloader: "DataLoader", sampler):
non_default_params.add("dataset")
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
reconstruct_args.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler))
batch_sampler = getattr(dataloader, "batch_sampler")
if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
raise RuntimeError("Execution should not reach this point; please report a bug to us.")
required_args = {
p.name
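This hunk and the one below are the part of ``replace_sampler`` that rebuilds the ``DataLoader`` from ``reconstruct_args``; as a reading aid, here is roughly what that reconstruction amounts to for a plain ``DataLoader`` (a simplified sketch of my own, not the actual fastNLP code, which also inspects the subclass's ``__init__`` signature and handles custom batch samplers):

```python
# Simplified sketch of what replace_sampler's reconstruction boils down to for
# a plain torch DataLoader: build the same kind of DataLoader again with the
# new sampler plugged in. Illustration only, not fastNLP's implementation.
from torch.utils.data import DataLoader

def toy_replace_sampler(dataloader: DataLoader, sampler) -> DataLoader:
    return DataLoader(
        dataset=dataloader.dataset,
        batch_size=dataloader.batch_size,
        sampler=sampler,                 # shuffle must stay unset once a sampler is given
        drop_last=dataloader.drop_last,
        collate_fn=dataloader.collate_fn,
        num_workers=dataloader.num_workers,
    )
```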
@@ -243,28 +245,6 @@ def replace_sampler(dataloader: "DataLoader", sampler):
return type(dataloader)(**reconstruct_args)
def _dataloader_init_kwargs_resolve_sampler(
dataloader: "DataLoader", sampler: Optional["Sampler"]
) -> Dict[str, Any]:
r""" | |||
此函数用于处理与 DataLoader 关联的采样器、batch_sampler 参数重新实例化; | |||
""" | |||
batch_sampler = getattr(dataloader, "batch_sampler")
# check whether the batch sampler type differs from the PyTorch default.
if batch_sampler is not None and not isinstance(batch_sampler, BatchSampler):
batch_sampler = re_instantiate_sampler(batch_sampler)
return {
"sampler": None,
"shuffle": False,
"batch_sampler": batch_sampler,
"batch_size": 1,
"drop_last": False,
}
return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
def replace_batch_sampler(dataloader, new_batch_sampler):
r"""
Replace the batch_sampler of a dataloader;