Browse Source

refactor(mge/data): rename `MapDataset` to `Dataset`

GitOrigin-RevId: 6262561355
release-1.2
Megvii Engine Team 4 years ago
parent
commit
05c739b846
5 changed files with 26 additions and 34 deletions
  1. +14
    -16
      imperative/python/megengine/data/dataloader.py
  2. +1
    -1
      imperative/python/megengine/data/dataset/__init__.py
  3. +8
    -12
      imperative/python/megengine/data/dataset/meta_dataset.py
  4. +2
    -2
      imperative/python/megengine/data/dataset/vision/meta_vision.py
  5. +1
    -3
      imperative/python/test/unit/data/test_dataset.py

+ 14
- 16
imperative/python/megengine/data/dataloader.py View File

@@ -19,7 +19,7 @@ import numpy as np
from ..logger import get_logger from ..logger import get_logger
from ..random.rng import _random_seed_generator from ..random.rng import _random_seed_generator
from .collator import Collator from .collator import Collator
from .dataset import Dataset, MapDataset, StreamDataset
from .dataset import Dataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform from .transform import PseudoTransform, Transform


@@ -88,7 +88,15 @@ class DataLoader:


self.divide = divide self.divide = divide


if isinstance(dataset, MapDataset):
if isinstance(dataset, StreamDataset):
self.sampler = sampler if sampler else StreamSampler(batch_size=1)
assert isinstance(
self.sampler, StreamSampler
), "types of dataset and sampler do not match"
else:
assert isinstance(
dataset, Dataset
), "Can not recognize this kind of dataset: %s" % type(dataset)
self.sampler = ( self.sampler = (
sampler sampler
if sampler if sampler
@@ -97,15 +105,6 @@ class DataLoader:
assert isinstance( assert isinstance(
self.sampler, MapSampler self.sampler, MapSampler
), "types of dataset and sampler do not match" ), "types of dataset and sampler do not match"
elif isinstance(dataset, StreamDataset):
self.sampler = sampler if sampler else StreamSampler(batch_size=1)
assert isinstance(
self.sampler, StreamSampler
), "types of dataset and sampler do not match"
else:
raise TypeError(
"can not recognize this kind of dataset: %s" % type(dataset)
)


if divide: if divide:
if self.sampler.batch_size <= self.num_workers: if self.sampler.batch_size <= self.num_workers:
@@ -140,15 +139,14 @@ class DataLoader:
return _SerialStreamDataLoaderIter(self) return _SerialStreamDataLoaderIter(self)
else: else:
return _ParallelStreamDataLoaderIter(self) return _ParallelStreamDataLoaderIter(self)
elif isinstance(self.dataset, MapDataset):
else:
assert isinstance(
self.dataset, Dataset
), "Can not recognize this kind of dataset: %s" % type(self.dataset)
if not self.num_workers: if not self.num_workers:
return _SerialMapDataLoaderIter(self) return _SerialMapDataLoaderIter(self)
else: else:
return _ParallelMapDataLoaderIter(self) return _ParallelMapDataLoaderIter(self)
else:
raise TypeError(
"can not recognize this kind of dataset: %s" % type(self.dataset)
)


def __len__(self): def __len__(self):
return len(self.sampler) return len(self.sampler)


+ 1
- 1
imperative/python/megengine/data/dataset/__init__.py View File

@@ -6,5 +6,5 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .meta_dataset import ArrayDataset, Dataset, MapDataset, StreamDataset
from .meta_dataset import ArrayDataset, Dataset, StreamDataset
from .vision import * from .vision import *

+ 8
- 12
imperative/python/megengine/data/dataset/meta_dataset.py View File

@@ -12,17 +12,7 @@ from typing import Tuple


class Dataset(ABC): class Dataset(ABC):
r""" r"""
An abstract class for all Datasets.
"""

@abstractmethod
def __init__(self):
pass


class MapDataset(Dataset):
r"""
An abstract class for map data.
An abstract class for all datasets.
__getitem__ and __len__ method are aditionally needed. __getitem__ and __len__ method are aditionally needed.
""" """


@@ -53,8 +43,14 @@ class StreamDataset(Dataset):
def __iter__(self): def __iter__(self):
pass pass


def __getitem__(self):
raise AssertionError("can not get item from StreamDataset by index")

def __len__(self):
raise AssertionError("StreamDataset does not have length")



class ArrayDataset(MapDataset):
class ArrayDataset(Dataset):
def __init__(self, *arrays): def __init__(self, *arrays):
r""" r"""
ArrayDataset is a dataset for numpy array data, one or more numpy arrays ArrayDataset is a dataset for numpy array data, one or more numpy arrays


+ 2
- 2
imperative/python/megengine/data/dataset/vision/meta_vision.py View File

@@ -9,10 +9,10 @@
import collections.abc import collections.abc
import os import os


from ..meta_dataset import MapDataset
from ..meta_dataset import Dataset




class VisionDataset(MapDataset):
class VisionDataset(Dataset):
_repr_indent = 4 _repr_indent = 4


def __init__(self, root, *, order=None, supported_order=None): def __init__(self, root, *, order=None, supported_order=None):


+ 1
- 3
imperative/python/test/unit/data/test_dataset.py View File

@@ -12,15 +12,13 @@ import sys
import numpy as np import numpy as np
import pytest import pytest


from megengine.data.dataset import ArrayDataset, Dataset, MapDataset, StreamDataset
from megengine.data.dataset import ArrayDataset, Dataset, StreamDataset




def test_abstract_cls(): def test_abstract_cls():
with pytest.raises(TypeError): with pytest.raises(TypeError):
Dataset() Dataset()
with pytest.raises(TypeError): with pytest.raises(TypeError):
MapDataset()
with pytest.raises(TypeError):
StreamDataset() StreamDataset()






Loading…
Cancel
Save