Browse Source

refactor(mge/data): rename `MapDataset` to `Dataset`

GitOrigin-RevId: 6262561355
release-1.2
Megvii Engine Team 4 years ago
parent
commit
05c739b846
5 changed files with 26 additions and 34 deletions
  1. +14
    -16
      imperative/python/megengine/data/dataloader.py
  2. +1
    -1
      imperative/python/megengine/data/dataset/__init__.py
  3. +8
    -12
      imperative/python/megengine/data/dataset/meta_dataset.py
  4. +2
    -2
      imperative/python/megengine/data/dataset/vision/meta_vision.py
  5. +1
    -3
      imperative/python/test/unit/data/test_dataset.py

+ 14
- 16
imperative/python/megengine/data/dataloader.py View File

@@ -19,7 +19,7 @@ import numpy as np
from ..logger import get_logger
from ..random.rng import _random_seed_generator
from .collator import Collator
from .dataset import Dataset, MapDataset, StreamDataset
from .dataset import Dataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform

@@ -88,7 +88,15 @@ class DataLoader:

self.divide = divide

if isinstance(dataset, MapDataset):
if isinstance(dataset, StreamDataset):
self.sampler = sampler if sampler else StreamSampler(batch_size=1)
assert isinstance(
self.sampler, StreamSampler
), "types of dataset and sampler do not match"
else:
assert isinstance(
dataset, Dataset
), "Can not recognize this kind of dataset: %s" % type(dataset)
self.sampler = (
sampler
if sampler
@@ -97,15 +105,6 @@ class DataLoader:
assert isinstance(
self.sampler, MapSampler
), "types of dataset and sampler do not match"
elif isinstance(dataset, StreamDataset):
self.sampler = sampler if sampler else StreamSampler(batch_size=1)
assert isinstance(
self.sampler, StreamSampler
), "types of dataset and sampler do not match"
else:
raise TypeError(
"can not recognize this kind of dataset: %s" % type(dataset)
)

if divide:
if self.sampler.batch_size <= self.num_workers:
@@ -140,15 +139,14 @@ class DataLoader:
return _SerialStreamDataLoaderIter(self)
else:
return _ParallelStreamDataLoaderIter(self)
elif isinstance(self.dataset, MapDataset):
else:
assert isinstance(
self.dataset, Dataset
), "Can not recognize this kind of dataset: %s" % type(self.dataset)
if not self.num_workers:
return _SerialMapDataLoaderIter(self)
else:
return _ParallelMapDataLoaderIter(self)
else:
raise TypeError(
"can not recognize this kind of dataset: %s" % type(self.dataset)
)

def __len__(self):
return len(self.sampler)


+ 1
- 1
imperative/python/megengine/data/dataset/__init__.py View File

@@ -6,5 +6,5 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .meta_dataset import ArrayDataset, Dataset, MapDataset, StreamDataset
from .meta_dataset import ArrayDataset, Dataset, StreamDataset
from .vision import *

+ 8
- 12
imperative/python/megengine/data/dataset/meta_dataset.py View File

@@ -12,17 +12,7 @@ from typing import Tuple

class Dataset(ABC):
r"""
An abstract class for all Datasets.
"""

@abstractmethod
def __init__(self):
pass


class MapDataset(Dataset):
r"""
An abstract class for map data.
An abstract class for all datasets.
__getitem__ and __len__ methods are additionally needed.
"""

@@ -53,8 +43,14 @@ class StreamDataset(Dataset):
def __iter__(self):
pass

def __getitem__(self):
raise AssertionError("can not get item from StreamDataset by index")

def __len__(self):
raise AssertionError("StreamDataset does not have length")


class ArrayDataset(MapDataset):
class ArrayDataset(Dataset):
def __init__(self, *arrays):
r"""
ArrayDataset is a dataset for numpy array data, one or more numpy arrays


+ 2
- 2
imperative/python/megengine/data/dataset/vision/meta_vision.py View File

@@ -9,10 +9,10 @@
import collections.abc
import os

from ..meta_dataset import MapDataset
from ..meta_dataset import Dataset


class VisionDataset(MapDataset):
class VisionDataset(Dataset):
_repr_indent = 4

def __init__(self, root, *, order=None, supported_order=None):


+ 1
- 3
imperative/python/test/unit/data/test_dataset.py View File

@@ -12,15 +12,13 @@ import sys
import numpy as np
import pytest

from megengine.data.dataset import ArrayDataset, Dataset, MapDataset, StreamDataset
from megengine.data.dataset import ArrayDataset, Dataset, StreamDataset


def test_abstract_cls():
with pytest.raises(TypeError):
Dataset()
with pytest.raises(TypeError):
MapDataset()
with pytest.raises(TypeError):
StreamDataset()




Loading…
Cancel
Save