Browse Source

修复mixdataloader在torch1.6下参数不匹配的问题

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
296c1acc31
1 changed files with 17 additions and 7 deletions
  1. +17
    -7
      fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py

+ 17
- 7
fastNLP/core/dataloaders/torch_dataloader/mix_dataloader.py View File

@@ -5,6 +5,7 @@ __all__ = [
from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence, Mapping

import numpy as np
from pkg_resources import parse_version

from fastNLP.core.dataset import DataSet, Instance
from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler
@@ -12,6 +13,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.collators import Collator

if _NEED_IMPORT_TORCH:
from torch import __version__ as torchversion
from torch.utils.data import DataLoader, Sampler
else:
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader
@@ -213,13 +215,21 @@ class MixDataLoader(DataLoader):
else:
raise ValueError(f"{mode} must be sequential, polling, mix or batch_sampler")

super(MixDataLoader, self).__init__(
_MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
pin_memory=pin_memory, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None, generator=None,
prefetch_factor=2, persistent_workers=False
)
if parse_version(torchversion) >= parse_version('1.7'):
super(MixDataLoader, self).__init__(
_MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
pin_memory=pin_memory, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None, generator=None,
prefetch_factor=2, persistent_workers=False
)
else:
super(MixDataLoader, self).__init__(
_MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
pin_memory=pin_memory, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None, generator=None,
)

def __iter__(self):
return super().__iter__()

Loading…
Cancel
Save