|
|
@@ -5,6 +5,7 @@ __all__ = [ |
|
|
|
from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence, Mapping |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
from pkg_resources import parse_version |
|
|
|
|
|
|
|
from fastNLP.core.dataset import DataSet, Instance |
|
|
|
from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler |
|
|
@@ -12,6 +13,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH |
|
|
|
from fastNLP.core.collators import Collator |
|
|
|
|
|
|
|
if _NEED_IMPORT_TORCH: |
|
|
|
from torch import __version__ as torchversion |
|
|
|
from torch.utils.data import DataLoader, Sampler |
|
|
|
else: |
|
|
|
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader |
|
|
@@ -213,13 +215,21 @@ class MixDataLoader(DataLoader): |
|
|
|
else: |
|
|
|
raise ValueError(f"{mode} must be sequential, polling, mix or batch_sampler") |
|
|
|
|
|
|
|
super(MixDataLoader, self).__init__( |
|
|
|
_MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, |
|
|
|
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, |
|
|
|
pin_memory=pin_memory, drop_last=False, timeout=0, |
|
|
|
worker_init_fn=None, multiprocessing_context=None, generator=None, |
|
|
|
prefetch_factor=2, persistent_workers=False |
|
|
|
) |
|
|
|
if parse_version(torchversion) >= parse_version('1.7'): |
|
|
|
super(MixDataLoader, self).__init__( |
|
|
|
_MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, |
|
|
|
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, |
|
|
|
pin_memory=pin_memory, drop_last=False, timeout=0, |
|
|
|
worker_init_fn=None, multiprocessing_context=None, generator=None, |
|
|
|
prefetch_factor=2, persistent_workers=False |
|
|
|
) |
|
|
|
else: |
|
|
|
super(MixDataLoader, self).__init__( |
|
|
|
_MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, |
|
|
|
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, |
|
|
|
pin_memory=pin_memory, drop_last=False, timeout=0, |
|
|
|
worker_init_fn=None, multiprocessing_context=None, generator=None, |
|
|
|
) |
|
|
|
|
|
|
|
def __iter__(self): |
|
|
|
return super().__iter__() |