diff --git a/fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py b/fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py index 29b0cd0b..1b77be77 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py +++ b/fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py @@ -5,6 +5,7 @@ __all__ = [ from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence, Mapping import numpy as np +from pkg_resources import parse_version from fastNLP.core.dataset import DataSet, Instance from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler @@ -12,6 +13,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH from fastNLP.core.collators import Collator if _NEED_IMPORT_TORCH: + from torch import __version__ as torchversion from torch.utils.data import DataLoader, Sampler else: from fastNLP.core.utils.dummy_class import DummyClass as DataLoader @@ -213,13 +215,21 @@ class MixDataLoader(DataLoader): else: raise ValueError(f"{mode} must be sequential, polling, mix or batch_sampler") - super(MixDataLoader, self).__init__( - _MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, - batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, - pin_memory=pin_memory, drop_last=False, timeout=0, - worker_init_fn=None, multiprocessing_context=None, generator=None, - prefetch_factor=2, persistent_workers=False - ) + if parse_version(torchversion) >= parse_version('1.7'): + super(MixDataLoader, self).__init__( + _MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, + batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, + pin_memory=pin_memory, drop_last=False, timeout=0, + worker_init_fn=None, multiprocessing_context=None, generator=None, + prefetch_factor=2, persistent_workers=False + ) + else: + super(MixDataLoader, self).__init__( + _MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, + batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, + pin_memory=pin_memory, drop_last=False, timeout=0, + worker_init_fn=None, multiprocessing_context=None, generator=None, + ) def __iter__(self): return super().__iter__()