meta_dataset.py 3.0 kB

# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod
from typing import Tuple


class Dataset(ABC):
    r"""An abstract base class for all map-style datasets.

    .. admonition:: Abstract methods

       All subclasses should override these two methods:

       * ``__getitem__()``: fetch a data sample for a given key.
       * ``__len__()``: return the size of the dataset.

       They play roles in the data pipeline, see the description below.

    .. admonition:: Dataset in the Data Pipeline

       Usually a dataset works with :class:`~.DataLoader`, :class:`~.Sampler`, :class:`~.Collator` and other components.

       For example, the sampler generates **indexes** of batches in advance according to the size of the dataset (calling ``__len__``).
       When the dataloader needs to yield a batch of data, it passes the indexes into the ``__getitem__`` method, then collates the fetched samples into a batch.

       * Reading :ref:`dataset-guide` is highly recommended for more details;
       * It might be helpful to read the implementation of :class:`~.MNIST`, :class:`~.CIFAR10` and other existing subclasses.

    .. warning::

       By default, all elements in a dataset are :class:`numpy.ndarray`.
       This means that if you want to do Tensor operations, it's better to do the conversion explicitly, such as:

       .. code-block:: python

          dataset = MyCustomDataset()           # A subclass of Dataset
          data, label = dataset[0]              # equivalent to dataset.__getitem__(0)
          data = Tensor(data, dtype="float32")  # convert to MegEngine Tensor explicitly

          megengine.functional.ops(data)

       Applying Tensor ops to an ndarray directly is undefined behavior.
    """

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def __getitem__(self, index):
        pass

    @abstractmethod
    def __len__(self):
        pass


class StreamDataset(Dataset):
    r"""An abstract class for stream data.

    Subclasses must additionally implement the ``__iter__`` method.
    Indexing and ``len()`` are not supported.
    """

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def __iter__(self):
        pass

    def __getitem__(self, idx):
        raise AssertionError("can not get item from StreamDataset by index")

    def __len__(self):
        raise AssertionError("StreamDataset does not have length")


class ArrayDataset(Dataset):
    r"""ArrayDataset is a dataset for numpy array data.

    One or more numpy arrays are needed to initialize the dataset.
    The dimension that represents the number of samples (the first one)
    is expected to be the same for all arrays.
    """

    def __init__(self, *arrays):
        super().__init__()
        # every array must contain the same number of samples
        if not all(len(arrays[0]) == len(array) for array in arrays):
            raise ValueError("lengths of input arrays are inconsistent")
        self.arrays = arrays

    def __getitem__(self, index: int) -> Tuple:
        # one element from each array, e.g. (data, label)
        return tuple(array[index] for array in self.arrays)

    def __len__(self) -> int:
        return len(self.arrays[0])
  62. def __len__(self) -> int:
  63. return len(self.arrays[0])