From 296c1acc31c4d2f8fab292e32bb0cce7ecdab617 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Wed, 25 May 2022 08:21:33 +0000 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dmixdataloader=E5=9C=A8torch1.?= =?UTF-8?q?6=E4=B8=8B=E5=8F=82=E6=95=B0=E4=B8=8D=E5=8C=B9=E9=85=8D?= =?UTF-8?q?=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../dataloaders/torch_dataloader/mix_dataloader.py | 24 +++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py b/fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py index 29b0cd0b..1b77be77 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py +++ b/fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py @@ -5,6 +5,7 @@ __all__ = [ from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence, Mapping import numpy as np +from pkg_resources import parse_version from fastNLP.core.dataset import DataSet, Instance from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler @@ -12,6 +13,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH from fastNLP.core.collators import Collator if _NEED_IMPORT_TORCH: + from torch import __version__ as torchversion from torch.utils.data import DataLoader, Sampler else: from fastNLP.core.utils.dummy_class import DummyClass as DataLoader @@ -213,13 +215,21 @@ class MixDataLoader(DataLoader): else: raise ValueError(f"{mode} must be sequential, polling, mix or batch_sampler") - super(MixDataLoader, self).__init__( - _MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, - batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, - pin_memory=pin_memory, drop_last=False, timeout=0, - worker_init_fn=None, multiprocessing_context=None, generator=None, - prefetch_factor=2, persistent_workers=False - ) + if parse_version(torchversion) >= parse_version('1.7'): + super(MixDataLoader, self).__init__( + _MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, + batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, + pin_memory=pin_memory, drop_last=False, timeout=0, + worker_init_fn=None, multiprocessing_context=None, generator=None, + prefetch_factor=2, persistent_workers=False + ) + else: + super(MixDataLoader, self).__init__( + _MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, + batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, + pin_memory=pin_memory, drop_last=False, timeout=0, + worker_init_fn=None, multiprocessing_context=None, generator=None, + ) def __iter__(self): return super().__iter__()