GitOrigin-RevId: 6262561355
release-1.2
@@ -19,7 +19,7 @@ import numpy as np
 from ..logger import get_logger
 from ..random.rng import _random_seed_generator
 from .collator import Collator
-from .dataset import Dataset, MapDataset, StreamDataset
+from .dataset import Dataset, StreamDataset
 from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
 from .transform import PseudoTransform, Transform
@@ -88,7 +88,15 @@ class DataLoader:
         self.divide = divide
-        if isinstance(dataset, MapDataset):
+        if isinstance(dataset, StreamDataset):
+            self.sampler = sampler if sampler else StreamSampler(batch_size=1)
+            assert isinstance(
+                self.sampler, StreamSampler
+            ), "types of dataset and sampler do not match"
+        else:
+            assert isinstance(
+                dataset, Dataset
+            ), "Can not recognize this kind of dataset: %s" % type(dataset)
             self.sampler = (
                 sampler
                 if sampler
@@ -97,15 +105,6 @@ class DataLoader:
             assert isinstance(
                 self.sampler, MapSampler
             ), "types of dataset and sampler do not match"
-        elif isinstance(dataset, StreamDataset):
-            self.sampler = sampler if sampler else StreamSampler(batch_size=1)
-            assert isinstance(
-                self.sampler, StreamSampler
-            ), "types of dataset and sampler do not match"
-        else:
-            raise TypeError(
-                "can not recognize this kind of dataset: %s" % type(dataset)
-            )
         if divide:
             if self.sampler.batch_size <= self.num_workers:
@@ -140,15 +139,14 @@ class DataLoader:
                 return _SerialStreamDataLoaderIter(self)
             else:
                 return _ParallelStreamDataLoaderIter(self)
-        elif isinstance(self.dataset, MapDataset):
+        else:
+            assert isinstance(
+                self.dataset, Dataset
+            ), "Can not recognize this kind of dataset: %s" % type(self.dataset)
             if not self.num_workers:
                 return _SerialMapDataLoaderIter(self)
             else:
                 return _ParallelMapDataLoaderIter(self)
-        else:
-            raise TypeError(
-                "can not recognize this kind of dataset: %s" % type(self.dataset)
-            )
     def __len__(self):
         return len(self.sampler)
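
For reference, a minimal usage sketch of the pairing the reworked dispatch enforces. This is not part of the patch: the DataLoader keyword arguments, the import paths, and the toy NoiseStream dataset are assumptions made for illustration only.

# Illustrative only: a stream-style dataset is paired with StreamSampler,
# while anything else must be a map-style Dataset paired with a MapSampler
# such as SequentialSampler, matching the asserts in the hunks above.
import numpy as np

from megengine.data.dataloader import DataLoader
from megengine.data.dataset import ArrayDataset, StreamDataset
from megengine.data.sampler import SequentialSampler, StreamSampler


class NoiseStream(StreamDataset):
    # Toy stream-style dataset: only __iter__ is provided.
    def __init__(self):
        pass

    def __iter__(self):
        while True:
            yield np.random.rand(4).astype("float32")


map_data = ArrayDataset(np.arange(16, dtype="float32"))
map_loader = DataLoader(map_data, sampler=SequentialSampler(map_data, batch_size=4))

stream_loader = DataLoader(NoiseStream(), sampler=StreamSampler(batch_size=4))

# Mismatched pairs now trip the asserts, e.g.
#   DataLoader(map_data, sampler=StreamSampler(batch_size=4))
# fails with "types of dataset and sampler do not match".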
@@ -6,5 +6,5 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from .meta_dataset import ArrayDataset, Dataset, MapDataset, StreamDataset
+from .meta_dataset import ArrayDataset, Dataset, StreamDataset
 from .vision import *
@@ -12,17 +12,7 @@ from typing import Tuple
 class Dataset(ABC):
     r"""
-    An abstract class for all Datasets.
-    """
-    @abstractmethod
-    def __init__(self):
-        pass
-class MapDataset(Dataset):
-    r"""
-    An abstract class for map data.
+    An abstract class for all datasets.
     __getitem__ and __len__ method are aditionally needed.
     """
@@ -53,8 +43,14 @@ class StreamDataset(Dataset):
     def __iter__(self):
         pass
+    def __getitem__(self):
+        raise AssertionError("can not get item from StreamDataset by index")
+    def __len__(self):
+        raise AssertionError("StreamDataset does not have length")
-class ArrayDataset(MapDataset):
+class ArrayDataset(Dataset):
     def __init__(self, *arrays):
         r"""
         ArrayDataset is a dataset for numpy array data, one or more numpy arrays
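
A brief sketch of what the flattened hierarchy means for user code; it is illustrative only, and the SquaresDataset/CounterStream names are made up for the example. With MapDataset gone, a map-style dataset subclasses Dataset directly and supplies __getitem__/__len__, while a stream-style dataset subclasses StreamDataset and supplies only __iter__.

# Illustrative sketch: subclassing after MapDataset is folded into Dataset.
import numpy as np

from megengine.data.dataset import Dataset, StreamDataset


class SquaresDataset(Dataset):
    # Map-style: __getitem__ and __len__ are the required extras.
    def __init__(self, n):
        self.values = np.arange(n, dtype="float32") ** 2

    def __getitem__(self, index):
        return self.values[index]

    def __len__(self):
        return len(self.values)


class CounterStream(StreamDataset):
    # Stream-style: only __iter__ is defined.
    def __init__(self):
        pass

    def __iter__(self):
        i = 0
        while True:
            yield np.float32(i)
            i += 1


squares = SquaresDataset(5)
print(len(squares), squares[3])  # -> 5 9.0

stream = CounterStream()
# len(stream) raises AssertionError("StreamDataset does not have length");
# index access is likewise rejected, per the methods added above.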
@@ -9,10 +9,10 @@
 import collections.abc
 import os
-from ..meta_dataset import MapDataset
+from ..meta_dataset import Dataset
-class VisionDataset(MapDataset):
+class VisionDataset(Dataset):
     _repr_indent = 4
     def __init__(self, root, *, order=None, supported_order=None):
@@ -12,15 +12,13 @@ import sys
 import numpy as np
 import pytest
-from megengine.data.dataset import ArrayDataset, Dataset, MapDataset, StreamDataset
+from megengine.data.dataset import ArrayDataset, Dataset, StreamDataset
 def test_abstract_cls():
     with pytest.raises(TypeError):
         Dataset()
     with pytest.raises(TypeError):
-        MapDataset()
-    with pytest.raises(TypeError):
         StreamDataset()