diff --git a/imperative/python/megengine/data/dataloader.py b/imperative/python/megengine/data/dataloader.py index 422b14a1..70aea9e7 100644 --- a/imperative/python/megengine/data/dataloader.py +++ b/imperative/python/megengine/data/dataloader.py @@ -68,6 +68,10 @@ class DataLoader: batch from workers. Default: 0 preload: whether to enable the preloading strategy of the dataloader. When enabling, the dataloader will preload one batch to the device memory to speed up the whole training process. + parallel_stream: whether to split the workload across all workers when the dataset is a StreamDataset and num_workers > 0. + When enabling, each worker will collect data from a different dataset in order to speed up the whole loading process. + See :ref:`streamdataset-example` for more details. + .. admonition:: The effect of enabling preload :class: warning diff --git a/imperative/python/megengine/data/dataset/meta_dataset.py b/imperative/python/megengine/data/dataset/meta_dataset.py index be8f921a..62ad4893 100644 --- a/imperative/python/megengine/data/dataset/meta_dataset.py +++ b/imperative/python/megengine/data/dataset/meta_dataset.py @@ -58,6 +58,33 @@ class StreamDataset(Dataset): r"""An abstract class for stream data. __iter__ method is aditionally needed. + + Examples: + + .. 
code-block:: python + + from megengine.data.dataset import StreamDataset + from megengine.data.dataloader import DataLoader, get_worker_info + from megengine.data.sampler import StreamSampler + + class MyStream(StreamDataset): + def __init__(self): + self.data = [iter([1, 2, 3]), iter([4, 5, 6]), iter([7, 8, 9])] + def __iter__(self): + worker_info = get_worker_info() + data_iter = self.data[worker_info.idx] + while True: + yield next(data_iter) + + dataloader = DataLoader( + dataset = MyStream(), + sampler = StreamSampler(batch_size=2), + num_workers=3, + parallel_stream = True, + ) + + for step, data in enumerate(dataloader): + print(data) """ @abstractmethod @@ -80,6 +107,29 @@ class ArrayDataset(Dataset): One or more numpy arrays are needed to initiate the dataset. And the dimensions represented sample number are expected to be the same. + + Examples: + + .. code-block:: python + + from megengine.data.dataset import ArrayDataset + from megengine.data.dataloader import DataLoader + from megengine.data.sampler import SequentialSampler + + rand_data = np.random.randint(0, 255, size=(100, 1, 32, 32), dtype=np.uint8) + label = np.random.randint(0, 10, size=(100,), dtype=int) + dataset = ArrayDataset(rand_data, label) + seque_sampler = SequentialSampler(dataset, batch_size=2) + + dataloader = DataLoader( + dataset, + sampler = seque_sampler, + num_workers=3, + ) + + for step, data in enumerate(dataloader): + print(data) + """ def __init__(self, *arrays):