
meta_dataset.py

# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod
from typing import Tuple


class Dataset(ABC):
    r"""An abstract base class for all map-style datasets.

    .. admonition:: Abstract methods

       All subclasses should overwrite these two methods:

       * ``__getitem__()``: fetch a data sample for a given key.
       * ``__len__()``: return the size of the dataset.

       They play roles in the data pipeline, see the description below.

    .. admonition:: Dataset in the Data Pipeline

       Usually a dataset works with :class:`~.DataLoader`, :class:`~.Sampler`, :class:`~.Collator` and other components.

       For example, the sampler generates **indexes** of batches in advance according to the size of the dataset (by calling ``__len__``).
       When the dataloader needs to yield a batch of data, it passes these indexes to the ``__getitem__`` method and collates the returned samples into a batch.

       * Reading :ref:`dataset-guide` for more details is highly recommended;
       * It might be helpful to read the implementations of :class:`~.MNIST`, :class:`~.CIFAR10` and other existing subclasses.

    .. warning::

       By default, all elements in a dataset are :class:`numpy.ndarray`.
       This means that if you want to do Tensor operations, it is better to do the conversion explicitly, for example:

       .. code-block:: python

          dataset = MyCustomDataset()            # a subclass of Dataset
          data, label = dataset[0]               # equivalent to dataset.__getitem__(0)
          data = Tensor(data, dtype="float32")   # convert to MegEngine Tensor explicitly

          megengine.functional.ops(data)

       Tensor ops on ndarray directly are undefined behaviors.
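
    Examples:

    A minimal sketch of the ``MyCustomDataset`` subclass mentioned in the warning above
    (the stored data is purely illustrative):

    .. code-block:: python

       import numpy as np
       from megengine.data.dataset import Dataset

       class MyCustomDataset(Dataset):
           def __init__(self):
               # two toy samples with their labels, stored as numpy arrays
               self.data = np.arange(8, dtype="float32").reshape(2, 4)
               self.label = np.array([0, 1], dtype="int32")

           def __getitem__(self, index):
               return self.data[index], self.label[index]

           def __len__(self):
               return len(self.data)

       dataset = MyCustomDataset()
       print(len(dataset))       # 2
       data, label = dataset[0]  # first sample and its label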
    """

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def __getitem__(self, index):
        pass

    @abstractmethod
    def __len__(self):
        pass


class StreamDataset(Dataset):
    r"""An abstract base class for stream data.

    The ``__iter__`` method is additionally required.

    Examples:

    .. code-block:: python

       from megengine.data.dataset import StreamDataset
       from megengine.data.dataloader import DataLoader, get_worker_info
       from megengine.data.sampler import StreamSampler

       class MyStream(StreamDataset):
           def __init__(self):
               self.data = [iter([1, 2, 3]), iter([4, 5, 6]), iter([7, 8, 9])]

           def __iter__(self):
               worker_info = get_worker_info()
               data_iter = self.data[worker_info.idx]
               while True:
                   yield next(data_iter)

       dataloader = DataLoader(
           dataset=MyStream(),
           sampler=StreamSampler(batch_size=2),
           num_workers=3,
           parallel_stream=True,
       )

       for step, data in enumerate(dataloader):
           print(data)
    """

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def __iter__(self):
        pass

    def __getitem__(self, idx):
        # Stream data is consumed by iteration, so random access is not supported.
        raise AssertionError("can not get item from StreamDataset by index")

    def __len__(self):
        # The length of a stream is unknown in advance.
        raise AssertionError("StreamDataset does not have length")


class ArrayDataset(Dataset):
    r"""ArrayDataset is a dataset for numpy array data.

    One or more numpy arrays are needed to initialize the dataset,
    and the dimensions representing the sample number are expected to be the same.

    Examples:

    .. code-block:: python

       import numpy as np

       from megengine.data.dataset import ArrayDataset
       from megengine.data.dataloader import DataLoader
       from megengine.data.sampler import SequentialSampler

       sample_num = 100  # illustrative number of samples
       rand_data = np.random.randint(0, 255, size=(sample_num, 1, 32, 32), dtype=np.uint8)
       label = np.random.randint(0, 10, size=(sample_num,), dtype=int)
       dataset = ArrayDataset(rand_data, label)
       seque_sampler = SequentialSampler(dataset, batch_size=2)

       dataloader = DataLoader(
           dataset,
           sampler=seque_sampler,
           num_workers=3,
       )

       for step, data in enumerate(dataloader):
           print(data)
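
    Indexing such a dataset directly returns a tuple containing one element from each
    input array (a minimal sketch, reusing the arrays from the example above):

    .. code-block:: python

       data, label = dataset[0]           # first image and its label
       assert data.shape == (1, 32, 32)
       assert len(dataset) == sample_num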
    """

    def __init__(self, *arrays):
        super().__init__()
        if not all(len(arrays[0]) == len(array) for array in arrays):
            raise ValueError("lengths of input arrays are inconsistent")
        self.arrays = arrays

    def __getitem__(self, index: int) -> Tuple:
        # Return one element from each input array, packed as a tuple.
        return tuple(array[index] for array in self.arrays)

    def __len__(self) -> int:
        return len(self.arrays[0])