|
|
@@ -19,7 +19,7 @@ import numpy as np |
|
|
|
from ..logger import get_logger |
|
|
|
from ..random.rng import _random_seed_generator |
|
|
|
from .collator import Collator |
|
|
|
from .dataset import Dataset, MapDataset, StreamDataset |
|
|
|
from .dataset import Dataset, StreamDataset |
|
|
|
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler |
|
|
|
from .transform import PseudoTransform, Transform |
|
|
|
|
|
|
@@ -88,7 +88,15 @@ class DataLoader: |
|
|
|
|
|
|
|
self.divide = divide |
|
|
|
|
|
|
|
if isinstance(dataset, MapDataset): |
|
|
|
if isinstance(dataset, StreamDataset): |
|
|
|
self.sampler = sampler if sampler else StreamSampler(batch_size=1) |
|
|
|
assert isinstance( |
|
|
|
self.sampler, StreamSampler |
|
|
|
), "types of dataset and sampler do not match" |
|
|
|
else: |
|
|
|
assert isinstance( |
|
|
|
dataset, Dataset |
|
|
|
), "Can not recognize this kind of dataset: %s" % type(dataset) |
|
|
|
self.sampler = ( |
|
|
|
sampler |
|
|
|
if sampler |
|
|
@@ -97,15 +105,6 @@ class DataLoader: |
|
|
|
assert isinstance( |
|
|
|
self.sampler, MapSampler |
|
|
|
), "types of dataset and sampler do not match" |
|
|
|
elif isinstance(dataset, StreamDataset): |
|
|
|
self.sampler = sampler if sampler else StreamSampler(batch_size=1) |
|
|
|
assert isinstance( |
|
|
|
self.sampler, StreamSampler |
|
|
|
), "types of dataset and sampler do not match" |
|
|
|
else: |
|
|
|
raise TypeError( |
|
|
|
"can not recognize this kind of dataset: %s" % type(dataset) |
|
|
|
) |
|
|
|
|
|
|
|
if divide: |
|
|
|
if self.sampler.batch_size <= self.num_workers: |
|
|
@@ -140,15 +139,14 @@ class DataLoader: |
|
|
|
return _SerialStreamDataLoaderIter(self) |
|
|
|
else: |
|
|
|
return _ParallelStreamDataLoaderIter(self) |
|
|
|
elif isinstance(self.dataset, MapDataset): |
|
|
|
else: |
|
|
|
assert isinstance( |
|
|
|
self.dataset, Dataset |
|
|
|
), "Can not recognize this kind of dataset: %s" % type(self.dataset) |
|
|
|
if not self.num_workers: |
|
|
|
return _SerialMapDataLoaderIter(self) |
|
|
|
else: |
|
|
|
return _ParallelMapDataLoaderIter(self) |
|
|
|
else: |
|
|
|
raise TypeError( |
|
|
|
"can not recognize this kind of dataset: %s" % type(self.dataset) |
|
|
|
) |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
return len(self.sampler) |
|
|
|