|
|
@@ -4,9 +4,41 @@ from typing import Tuple |
|
|
|
|
|
|
|
|
|
|
|
class Dataset(ABC): |
|
|
|
r"""An abstract base class for all datasets. |
|
|
|
r"""An abstract base class for all map-style datasets. |
|
|
|
|
|
|
|
__getitem__ and __len__ method are aditionally needed. |
|
|
|
.. admonition:: Abstract methods |
|
|
|
|
|
|
|
All subclasses should overwrite these two methods: |
|
|
|
|
|
|
|
* ``__getitem__()``: fetch a data sample for a given key. |
|
|
|
* ``__len__()``: return the size of the dataset. |
|
|
|
|
|
|
|
They play roles in the data pipeline, see the description below. |
|
|
|
|
|
|
|
.. admonition:: Dataset in the Data Pipline |
|
|
|
|
|
|
|
Usually a dataset works with :class:`~.DataLoader`, :class:`~.Sampler`, :class:`~.Collator` and other components. |
|
|
|
|
|
|
|
For example, the sampler generates **indexes** of batches in advance according to the size of the dataset (calling ``__len__``), |
|
|
|
When dataloader need to yield a batch of data, pass indexes into the ``__getitem__`` method, then collate them to a batch. |
|
|
|
|
|
|
|
* Highly recommended reading :ref:`dataset-guide` for more details; |
|
|
|
* It might helpful to read the implementation of :class:`~.MNIST`, :class:`~.CIFAR10` and other existed subclass. |
|
|
|
|
|
|
|
.. warning:: |
|
|
|
|
|
|
|
By default, all elements in a dataset would be :class:`numpy.ndarray`. |
|
|
|
It means that if you want to do Tensor operations, it's better to do the conversion explicitly, such as: |
|
|
|
|
|
|
|
.. code-block:: python |
|
|
|
|
|
|
|
dataset = MyCustomDataset() # A subclass of Dataset |
|
|
|
data, label = MyCustomDataset[0] # equals to MyCustomDataset.__getitem__[0] |
|
|
|
data = Tensor(data, dtype="float32") # convert to MegEngine Tensor explicitly |
|
|
|
|
|
|
|
megengine.functional.ops(data) |
|
|
|
|
|
|
|
Tensor ops on ndarray directly are undefined behaviors. |
|
|
|
""" |
|
|
|
|
|
|
|
@abstractmethod |
|
|
|