From f04e0d777e6b2bd4d1f9ee5adda9fc2f83eebdb2 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 21 Oct 2020 13:46:35 +0800
Subject: [PATCH] feat(mge/data): dpflow dataset, stream sampler and loader

GitOrigin-RevId: cbb4510a13625e7c2203cd1358a96208849029ca
---
 imperative/python/megengine/data/__init__.py   |   1 +
 imperative/python/megengine/data/dataloader.py | 209 +++++++++++++++++++++++--
 imperative/python/megengine/data/sampler.py    |  42 ++++-
 3 files changed, 236 insertions(+), 16 deletions(-)

diff --git a/imperative/python/megengine/data/__init__.py b/imperative/python/megengine/data/__init__.py
index 3b1e0d55..11398efe 100644
--- a/imperative/python/megengine/data/__init__.py
+++ b/imperative/python/megengine/data/__init__.py
@@ -14,4 +14,5 @@ from .sampler import (
     ReplacementSampler,
     Sampler,
     SequentialSampler,
+    StreamSampler,
 )
diff --git a/imperative/python/megengine/data/dataloader.py b/imperative/python/megengine/data/dataloader.py
index 3bd01993..d6c55422 100644
--- a/imperative/python/megengine/data/dataloader.py
+++ b/imperative/python/megengine/data/dataloader.py
@@ -19,8 +19,8 @@ import numpy as np
 from ..logger import get_logger
 from ..random.rng import _random_seed_generator
 from .collator import Collator
-from .dataset import Dataset
-from .sampler import Sampler, SequentialSampler
+from .dataset import Dataset, MapDataset, StreamDataset
+from .sampler import Sampler, SequentialSampler, StreamSampler
 from .transform import PseudoTransform, Transform
 
 logger = get_logger(__name__)
@@ -82,13 +82,21 @@
             raise ValueError("divide should not be set to True when num_workers <= 1")
 
         self.dataset = dataset
+        self.num_workers = num_workers
         self.timeout = timeout
         self.divide = divide
 
         if sampler is None:
-            self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False)
+            if isinstance(dataset, MapDataset):
+                self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False)
+            elif isinstance(dataset, StreamDataset):
+                self.sampler = StreamSampler(batch_size=1)
+            else:
+                raise TypeError(
+                    "can not recognize this kind of dataset: %s" % type(dataset)
+                )
         else:
             self.sampler = sampler
@@ -120,16 +128,26 @@
                 "pyarrow.plasma does not support ParallelDataLoader on windows, changing num_workers to be zero"
             )
             self.num_workers = 0
-        if self.num_workers == 0:
-            return _SerialDataLoaderIter(self)
+        if isinstance(self.dataset, StreamDataset):
+            if not self.num_workers:
+                return _SerialStreamDataLoaderIter(self)
+            else:
+                return _ParallelStreamDataLoaderIter(self)
+        elif isinstance(self.dataset, MapDataset):
+            if not self.num_workers:
+                return _SerialMapDataLoaderIter(self)
+            else:
+                return _ParallelMapDataLoaderIter(self)
         else:
-            return _ParallelDataLoaderIter(self)
+            raise TypeError(
+                "can not recognize this kind of dataset: %s" % type(self.dataset)
+            )
 
     def __len__(self):
         return len(self.sampler)
 
 
-class _BaseDataLoaderIter:
+class _BaseMapDataLoaderIter:
     def __init__(self, loader):
         self.dataset = loader.dataset
         self.sampler = loader.sampler
@@ -158,9 +176,9 @@
         return minibatch
 
 
-class _SerialDataLoaderIter(_BaseDataLoaderIter):
+class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
     def __init__(self, loader):
-        super(_SerialDataLoaderIter, self).__init__(loader)
+        super(_SerialMapDataLoaderIter, self).__init__(loader)
         self.indices_iter = iter(self.sampler)
 
     def _get_next_batch(self):
@@ -170,11 +188,11 @@
         return self.collator.apply(trans_items)
 
 
-class _ParallelDataLoaderIter(_BaseDataLoaderIter):
+class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
     __initialized = False
 
     def __init__(self, loader):
-        super(_ParallelDataLoaderIter, self).__init__(loader)
+        super(_ParallelMapDataLoaderIter, self).__init__(loader)
 
         self.task_queues = [
             multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
         ]
@@ -326,6 +344,175 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
         self._shutdown()
 
 
+class _BaseStreamDataLoaderIter:
+    def __init__(self, loader):
+        self.dataset = loader.dataset
+        self.sampler = loader.sampler
+        self.transform = loader.transform
+        self.collator = loader.collator
+        self.num_workers = loader.num_workers
+        self.timeout = loader.timeout
+        self.post_process = self.dataset.post_process
+
+    def _get_next_batch(self):
+        raise NotImplementedError
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return self.post_process(self._get_next_batch())
+
+
+class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
+    def __init__(self, loader):
+        super().__init__(loader)
+        self.dataset_iter = iter(self.dataset)
+
+    def _get_next_batch(self):
+        ret = []
+        start_time = time.time()
+        while len(ret) != self.sampler.batch_size:
+            waited_time = time.time() - start_time
+            if self.timeout > 0 and waited_time > self.timeout:
+                raise RuntimeError("get_next_batch timeout!")
+            item = next(self.dataset_iter)
+            for idx in range(len(item[0])):
+                trans_item = self.transform.apply(tuple(e[idx] for e in item))
+                ret.append(trans_item)
+                if len(ret) == self.sampler.batch_size:
+                    break
+        return self.collator.apply(ret)
+
+
+class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
+    __initialized = False
+
+    def __init__(self, loader):
+        super().__init__(loader)
+
+        self.shutdown_flag = multiprocessing.Value("i", 0)
+
+        # shared-memory queue implemented by pyarrow plasma store
+        from ._queue import PlasmaShmQueue
+
+        self.batch_queue = PlasmaShmQueue(maxsize=2)
+
+        self.workers = []
+        self.worker_queues = [
+            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
+        ]
+        for worker_id in range(self.num_workers):
+            worker = multiprocessing.Process(
+                target=self._gen_data, args=(worker_id,), daemon=True
+            )
+            worker.start()
+            self.workers.append(worker)
+
+        self.collator_worker = multiprocessing.Process(
+            target=self._gen_batch, daemon=True
+        )
+        self.collator_worker.start()
+
+        self.__initialized = True
+
+    def _gen_data(self, worker_id):
+        dataset_iter = iter(self.dataset)
+        while True:
+            if self.shutdown_flag.value == 1:
+                break
+            item = next(dataset_iter)
+            for idx in range(len(item[0])):
+                trans_item = self.transform.apply(tuple(e[idx] for e in item))
+                while True:
+                    try:
+                        self.worker_queues[worker_id].put(trans_item)
+                        break
+                    except queue.Full:
+                        if self.shutdown_flag.value == 1:
+                            break
+                        logger.debug("batch part queue is full")
+
+    def _gen_batch(self):
+        cnt = -1
+        trans_items = []
+        while True:
+            if self.shutdown_flag.value == 1:
+                break
+            cnt += 1
+            queue_id = cnt % self.num_workers
+            try:
+                trans_item = self.worker_queues[queue_id].get(
+                    timeout=MP_QUEUE_GET_TIMEOUT
+                )
+            except queue.Empty:
+                continue
+            trans_items.append(trans_item)
+            if len(trans_items) == self.sampler.batch_size:
+                batch_data = self.collator.apply(trans_items)
+                while True:
+                    try:
+                        self.batch_queue.put(batch_data, timeout=1)
+                        break
+                    except queue.Full:
+                        if self.shutdown_flag.value == 1:
+                            break
+                        logger.debug("batch queue is full")
+                trans_items = []
+
+    def _check_workers(self):
+        if not self.collator_worker.is_alive():
+            exitcode = self.collator_worker.exitcode
+            if exitcode != 0:
+                raise RuntimeError("collator worker died. {}".format(exitcode))
+
+        for worker_id, worker in enumerate(self.workers):
+            if not worker.is_alive():
+                exitcode = worker.exitcode
+                if exitcode != 0:
+                    raise RuntimeError(
+                        "worker: {} died. {}".format(worker_id, exitcode)
+                    )
+
+    def _try_get_next_batch(self):
+        start_time = time.time()
+        while True:
+            self._check_workers()
+            try:
+                return self.batch_queue.get(timeout=1)
+            except queue.Empty:
+                logger.debug("batch queue empty!")
+            waited_time = time.time() - start_time
+            if self.timeout > 0 and waited_time > self.timeout:
+                raise RuntimeError("get_next_batch timeout!")
+
+    def _get_next_batch(self):
+        batch_data = self._try_get_next_batch()
+        return batch_data
+
+    def _shutdown(self):
+        with self.shutdown_flag.get_lock():
+            self.shutdown_flag.value = 1
+
+        if self.collator_worker.is_alive():
+            self.collator_worker.terminate()
+        self.collator_worker.join()
+
+        for worker in self.workers:
+            if worker.is_alive():
+                worker.terminate()
+            worker.join()
+
+        for q in self.worker_queues:
+            q.cancel_join_thread()
+            q.close()
+
+        self.batch_queue.cancel_join_thread()
+        self.batch_queue.close()
+
+    def __del__(self):
+        if self.__initialized:
+            self._shutdown()
+
+
 def _task_feeding_loop(
     indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
 ):
diff --git a/imperative/python/megengine/data/sampler.py b/imperative/python/megengine/data/sampler.py
index 3a748ae7..a260fa87 100644
--- a/imperative/python/megengine/data/sampler.py
+++ b/imperative/python/megengine/data/sampler.py
@@ -8,7 +8,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections.abc
 import math
-from abc import ABC
+from abc import ABC, abstractmethod
 from typing import Any, Generator, Iterator, List, Union
 
 import numpy as np
@@ -17,6 +17,16 @@ import megengine.distributed as dist
 
 
 class Sampler(ABC):
+    r"""
+    An abstract base class for all samplers.
+    """
+
+    @abstractmethod
+    def __init__(self):
+        pass
+
+
+class MapSampler(Sampler):
     def __init__(
         self,
         dataset,
@@ -145,7 +155,29 @@
         return iter(batch_index)
 
 
-class SequentialSampler(Sampler):
+class StreamSampler(Sampler):
+    """
+    Sampler for stream dataset.
+
+    .. warning::
+
+        In the case of multiple workers, the sampler should ensure that each
+        worker gets distinct data. This class cannot do that yet; please build
+        your own dataset and sampler to achieve this goal.
+    """
+
+    def __init__(self, batch_size=1):
+        self.batch_size = batch_size
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return range(self.batch_size)
+
+
+class SequentialSampler(MapSampler):
     def __init__(
         self,
         dataset,
@@ -176,7 +208,7 @@
         return self.indices
 
 
-class RandomSampler(Sampler):
+class RandomSampler(MapSampler):
     def __init__(
         self,
         dataset,
@@ -205,7 +237,7 @@
         return self.rng.permutation(self.indices).tolist()
 
 
-class ReplacementSampler(Sampler):
+class ReplacementSampler(MapSampler):
     def __init__(
         self,
         dataset,
@@ -249,7 +281,7 @@
         return self.rng.multinomial(n, self.weights, self.num_samples).tolist()
 
 
-class Infinite(Sampler):
+class Infinite(MapSampler):
     r"""Infinite Sampler wrapper for a basic sampler."""
 
     def sample(self):
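
Usage sketch (illustrative, not part of the patch). Everything below is inferred
only from the loader code above: the stream iterators expect each dataset item
to be a tuple of equal-length per-field arrays (they slice
tuple(e[idx] for e in item) per sample) and expect the dataset to expose the
post_process hook that _BaseStreamDataLoaderIter calls on every collated batch.
RandomStream, its field shapes, and the identity post_process are hypothetical,
and the sketch assumes a StreamDataset subclass only needs to define __init__,
__iter__ and post_process.

    import numpy as np

    from megengine.data import DataLoader, StreamSampler
    from megengine.data.dataset import StreamDataset


    class RandomStream(StreamDataset):
        """Hypothetical stream source yielding (images, labels) chunks of 4 samples."""

        def __init__(self):
            super().__init__()

        def __iter__(self):
            while True:
                # each field has the same leading length (4), as the loader expects
                images = np.random.rand(4, 3, 32, 32).astype("float32")
                labels = np.random.randint(0, 10, size=(4, 1))
                yield (images, labels)

        def post_process(self, batch):
            # identity hook; _BaseStreamDataLoaderIter.__next__ applies it per batch
            return batch


    # num_workers=0 selects _SerialStreamDataLoaderIter, which slices samples
    # out of each 4-sample chunk and re-batches them to batch_size=16 with the
    # loader's default transform and collator
    dataloader = DataLoader(RandomStream(), sampler=StreamSampler(batch_size=16))

    for step, (images, labels) in enumerate(dataloader):
        print(step, images.shape, labels.shape)  # (16, 3, 32, 32) (16, 1)
        if step >= 2:
            break  # the stream never raises StopIteration; stop manually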