- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import gc
- import math
- import multiprocessing
- import platform
- import queue
- import random
- import threading
- import time
- from typing import Callable
-
- import numpy as np
-
- from ..logger import get_logger
- from ..random.rng import _random_seed_generator
- from .collator import Collator
- from .dataset import Dataset, StreamDataset
- from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
- from .transform import PseudoTransform, Transform
-
- try:
- import thread
- except ImportError:
- import _thread as thread
-
-
- logger = get_logger(__name__)
-
-
- GLOBAL_TIMEOUT = 5
-
-
- def raise_timeout_error():
- raise RuntimeError("dataloader timeout")
-
-
- class DataLoader:
- r"""Provides a convenient way to iterate on a given dataset.
-
- DataLoader combines a dataset with
- :class:`~.Sampler`, :class:`~.Transform` and :class:`~.Collator`,
- making it convenient to fetch minibatches from a dataset continually.
-
- :param dataset: dataset from which to load the minibatch.
- :param sampler: defines the strategy to sample data from the dataset.
- :param transform: defines the transform strategy for a sampled batch.
- Default: None
- :param collator: defines the merging strategy for a transformed batch.
- Default: None
- :param num_workers: the number of sub-processes used to load, transform and
- collate the batch. ``0`` means single-process loading. Default: 0
- :param timeout: if positive, the timeout value (in seconds) for collecting a
- batch from workers. Default: 0
- :param timeout_event: callback function triggered by timeout; defaults to
- raising a runtime error.
- :param divide: defines the parallelization strategy in multi-process mode.
- ``True`` means each batch is divided into :attr:`num_workers` pieces which
- the workers process in parallel. ``False`` means each sub-process processes
- a whole, different batch. Default: False
-
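- Examples:
-
- A minimal usage sketch; ``ArrayDataset`` and ``RandomSampler`` are assumed
- to be importable from ``megengine.data`` (illustrative only):
-
- .. code-block:: python
-
- import numpy as np
- from megengine.data import DataLoader, RandomSampler
- from megengine.data.dataset import ArrayDataset
-
- # 100 fake images of shape (3, 32, 32) with integer labels
- data = np.random.random((100, 3, 32, 32)).astype("float32")
- label = np.random.randint(0, 10, (100,)).astype("int32")
- dataset = ArrayDataset(data, label)
- # shuffle the dataset and yield minibatches of 8 samples
- dataloader = DataLoader(dataset, sampler=RandomSampler(dataset, batch_size=8))
- for batch_data, batch_label in dataloader:
- pass  # batch_data: (8, 3, 32, 32) for full batches
-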
- """
- __initialized = False
-
- def __init__(
- self,
- dataset: Dataset,
- sampler: Sampler = None,
- transform: Transform = None,
- collator: Collator = None,
- num_workers: int = 0,
- timeout: int = 0,
- timeout_event: Callable = raise_timeout_error,
- divide: bool = False,
- ):
- if num_workers < 0:
- raise ValueError("num_workers should not be negative")
-
- if timeout < 0:
- raise ValueError("timeout should not be negative")
-
- if divide and num_workers <= 1:
- raise ValueError("divide should not be set to True when num_workers <= 1")
-
- self.dataset = dataset
-
- self.num_workers = num_workers
- self.timeout = timeout
- self.timeout_event = timeout_event
-
- self.divide = divide
-
- if isinstance(dataset, StreamDataset):
- self.sampler = sampler if sampler else StreamSampler(batch_size=1)
- assert isinstance(
- self.sampler, StreamSampler
- ), "types of dataset and sampler do not match"
- else:
- assert isinstance(
- dataset, Dataset
- ), "Can not recognize this kind of dataset: %s" % type(dataset)
- self.sampler = (
- sampler
- if sampler
- else SequentialSampler(dataset, batch_size=1, drop_last=False)
- )
- assert isinstance(
- self.sampler, MapSampler
- ), "types of dataset and sampler do not match"
-
- if divide:
- if self.sampler.batch_size <= self.num_workers:
- raise ValueError(
- "batch size must not smaller than num_workers in divide mode."
- )
- elif self.sampler.batch_size % self.num_workers:
- logger.warning(
- "batch size is not divisible by num_workers, may lose performance in divide mode."
- )
-
- if transform is None:
- self.transform = PseudoTransform()
- else:
- self.transform = transform
-
- if collator is None:
- self.collator = Collator()
- else:
- self.collator = collator
-
- self.__initialized = True
-
- def __iter__(self):
- if platform.system() == "Windows" and self.num_workers > 0:
- logger.warning(
- "pyarrow.plasma does not support ParallelDataLoader on Windows; setting num_workers to 0"
- )
- self.num_workers = 0
- if isinstance(self.dataset, StreamDataset):
- if not self.num_workers:
- return _SerialStreamDataLoaderIter(self)
- else:
- return _ParallelStreamDataLoaderIter(self)
- else:
- assert isinstance(
- self.dataset, Dataset
- ), "Can not recognize this kind of dataset: %s" % type(self.dataset)
- if not self.num_workers:
- return _SerialMapDataLoaderIter(self)
- else:
- return _ParallelMapDataLoaderIter(self)
-
- def __len__(self):
- return len(self.sampler)
-
-
- class _BaseMapDataLoaderIter:
- def __init__(self, loader):
- self.dataset = loader.dataset
- self.sampler = loader.sampler
- self.seed = next(_random_seed_generator())
- self.transform = loader.transform
- self.collator = loader.collator
- self.num_workers = loader.num_workers
- self.timeout = loader.timeout
- self.timeout_event = loader.timeout_event
- self.divide = loader.divide
- self.num_processed = 0
-
- def _get_next_batch(self):
- raise NotImplementedError
-
- def __len__(self):
- return len(self.sampler)
-
- def __iter__(self):
- return self
-
- def __next__(self):
- if self.num_processed >= len(self):
- raise StopIteration
- minibatch = self._get_next_batch()
- self.num_processed += 1
- return minibatch
-
-
- class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
- def __init__(self, loader):
- super(_SerialMapDataLoaderIter, self).__init__(loader)
- self.indices_iter = iter(self.sampler)
-
- def _get_next_batch(self):
- indices = next(self.indices_iter)
- items = [self.dataset[idx] for idx in indices]
- trans_items = self.transform.apply_batch(items)
- return self.collator.apply(trans_items)
-
-
- class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
- __initialized = False
-
- def __init__(self, loader):
- super(_ParallelMapDataLoaderIter, self).__init__(loader)
-
- self.task_queues = [
- multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
- ]
-
- self.feed_batch_idx = multiprocessing.Value("i", 0)
- self.target_batch_idx = multiprocessing.Value("i", 0)
- self.shutdown_flag = multiprocessing.Value("i", 0)
-
- self.trans_data_queues = [
- multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
- ]
-
- # use shared-memory queue implemented by pyarrow plasma store.
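- # (plasma keeps payloads in shared memory, so large arrays cross processes without being pickled through pipes)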
- from .tools._queue import PlasmaShmQueue
-
- self.batch_queue = PlasmaShmQueue(maxsize=2)
-
- self.task_feeding_worker = multiprocessing.Process(
- target=_task_feeding_loop,
- args=(
- iter(self.sampler),
- self.task_queues,
- self.num_workers,
- self.divide,
- self.shutdown_flag,
- self.feed_batch_idx,
- ),
- daemon=True,
- )
- gc.collect()
- self.task_feeding_worker.start()
-
- self.workers = []
- for worker_id in range(self.num_workers):
- worker = multiprocessing.Process(
- target=_worker_loop,
- args=(
- self.dataset,
- self.task_queues[worker_id],
- self.trans_data_queues[worker_id],
- self.transform,
- self.seed + worker_id + 1,
- self.shutdown_flag,
- ),
- daemon=True,
- )
- gc.collect()
- worker.start()
- self.workers.append(worker)
-
- if self.divide:
- self.data_collecting_worker = multiprocessing.Process(
- target=_data_gathering_loop,
- args=(
- self.trans_data_queues,
- self.batch_queue,
- self.collator,
- len(self),
- self.num_workers,
- self.shutdown_flag,
- self.target_batch_idx,
- ),
- daemon=True,
- )
- else:
- self.data_collecting_worker = multiprocessing.Process(
- target=_data_selecting_loop,
- args=(
- self.trans_data_queues,
- self.batch_queue,
- self.collator,
- len(self),
- self.num_workers,
- self.shutdown_flag,
- self.target_batch_idx,
- ),
- daemon=True,
- )
- gc.collect()
- self.data_collecting_worker.start()
-
- self.__initialized = True
-
- def _check_workers(self):
- # Check the status of each worker.
- if not self.data_collecting_worker.is_alive():
- exitcode = self.data_collecting_worker.exitcode
- if exitcode != 0:
- raise RuntimeError("data collecting worker died. {}".format(exitcode))
-
- if not self.task_feeding_worker.is_alive():
- exitcode = self.task_feeding_worker.exitcode
- if exitcode != 0:
- raise RuntimeError("task feeding worker died. {}".format(exitcode))
-
- for worker_id, worker in enumerate(self.workers):
- if not worker.is_alive():
- exitcode = worker.exitcode
- if exitcode != 0:
- raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode))
-
- logger.debug("all workers are alive.")
-
- def _get_next_batch(self):
- start_time = time.time()
- while True:
- self._check_workers()
- try:
- return self.batch_queue.get(timeout=1)
- except queue.Empty:
- logger.debug("batch queue empty!")
- waited_time = time.time() - start_time
- if self.timeout > 0 and waited_time > self.timeout:
- return self.timeout_event()
-
- def _shutdown(self):
- with self.shutdown_flag.get_lock():
- self.shutdown_flag.value = 1
-
- if self.task_feeding_worker.is_alive():
- self.task_feeding_worker.terminate()
- self.task_feeding_worker.join()
-
- if self.data_collecting_worker.is_alive():
- self.data_collecting_worker.terminate()
- self.data_collecting_worker.join()
-
- for worker in self.workers:
- if worker.is_alive():
- worker.terminate()
- worker.join()
-
- for q in self.trans_data_queues:
- q.cancel_join_thread()
- q.close()
-
- for q in self.task_queues:
- q.cancel_join_thread()
- q.close()
-
- self.batch_queue.cancel_join_thread()
- self.batch_queue.close()
-
- def __del__(self):
- if self.__initialized:
- self._shutdown()
-
-
- class _BaseStreamDataLoaderIter:
- def __init__(self, loader):
- self.dataset = loader.dataset
- self.sampler = loader.sampler
- self.transform = loader.transform
- self.collator = loader.collator
- self.num_workers = loader.num_workers
- self.timeout = loader.timeout
- self.timeout_event = loader.timeout_event
-
- def _get_next_batch(self):
- raise NotImplementedError
-
- def _process_raw_data(self, raw_data):
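- # Illustrative sketch of the contract (names x0, x1, y0, y1 are assumed):
- # batched: raw_data = (True, ([x0, x1], [y0, y1])) -> [(x0, y0), (x1, y1)]
- # un-batched: raw_data = (False, (x0, y0)) -> [(x0, y0)]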
- assert len(raw_data) == 2 and isinstance(
- raw_data[0], bool
- ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
- if not raw_data[0]:
- data = list((x,) for x in raw_data[1])
- else:
- data = raw_data[1]
- ret = []
- for idx in range(len(data[0])):
- ret.append(tuple(e[idx] for e in data))
- return ret
-
- def __iter__(self):
- return self
-
- def __next__(self):
- return self._get_next_batch()
-
-
- class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
- def __init__(self, loader):
- super().__init__(loader)
- self.dataset_iter = iter(self.dataset)
- self.idx = 0
- self.unused = []
-
- def _try_get_raw_data(self, start_time):
- raw_data = None
- while raw_data is None:
- try:
- if self.timeout > 0:
- timer = threading.Timer(self.timeout, thread.interrupt_main)
- timer.start()
- raw_data = next(self.dataset_iter)
- if self.timeout > 0:
- timer.cancel()
- except KeyboardInterrupt:
- raw_data = self.timeout_event()
- except Exception:
- if self.timeout > 0:
- timer.cancel()
- waited_time = time.time() - start_time
- if waited_time > self.timeout:
- raw_data = self.timeout_event()
- return raw_data
-
- def _get_next_batch(self):
- ret = []
- start_time = time.time()
- while len(ret) < self.sampler.batch_size:
- if len(self.unused) != 0:
- batch_data = self.unused
- else:
- raw_data = self._try_get_raw_data(start_time)
- batch_data = self._process_raw_data(raw_data)
-
- while len(batch_data) != 0 and len(ret) < self.sampler.batch_size:
- data = batch_data.pop()
- ret.append(self.transform.apply(data))
- self.unused = batch_data
-
- return self.collator.apply(ret)
-
-
- class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
- __initialized = False
-
- def __init__(self, loader):
- super().__init__(loader)
-
- self.shutdown_flag = multiprocessing.Value("i", 0)
-
- self.raw_data_queues = [
- multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
- ]
-
- self.trans_data_queues = [
- multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
- ]
-
- # shared-memory queue implemented by pyarrow plasma store
- from .tools._queue import PlasmaShmQueue
-
- self.batch_queue = PlasmaShmQueue(maxsize=2)
-
- self.receive_worker = multiprocessing.Process(
- target=self._worker_to_raw_data_queues, daemon=True
- )
- gc.collect()
- self.receive_worker.start()
-
- self.transform_workers = []
- for worker_id in range(self.num_workers):
- worker = multiprocessing.Process(
- target=self._worker_to_trans_data_queues, args=(worker_id,), daemon=True
- )
- gc.collect()
- worker.start()
- self.transform_workers.append(worker)
-
- self.collect_worker = multiprocessing.Process(
- target=self._worker_to_batch_queue, daemon=True
- )
- gc.collect()
- self.collect_worker.start()
-
- self.__initialized = True
-
- def _put_raw_data_queues(self, raw_data, qidx):
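- # Flatten raw_data into samples and spread them round-robin over the per-worker raw data queues, starting at qidx.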
- batch_data = self._process_raw_data(raw_data)
- for data in batch_data:
- while True:
- qidx = qidx % self.num_workers
- try:
- self.raw_data_queues[qidx].put(data, timeout=GLOBAL_TIMEOUT)
- break
- except queue.Full:
- if self.shutdown_flag.value == 1:
- break
- logger.debug("raw data queue %d is full" % qidx)
- finally:
- qidx += 1
- return qidx
-
- def _worker_to_raw_data_queues(self):
- dataset_iter = iter(self.dataset)
- qidx = 0
- while True:
- if self.shutdown_flag.value == 1:
- break
- raw_data = next(dataset_iter)
- qidx = self._put_raw_data_queues(raw_data, qidx)
-
- def _worker_to_trans_data_queues(self, worker_id):
- while True:
- if self.shutdown_flag.value == 1:
- break
- try:
- data = self.raw_data_queues[worker_id].get(timeout=GLOBAL_TIMEOUT)
- except queue.Empty:
- continue
- trans_data = self.transform.apply(data)
- while True:
- try:
- self.trans_data_queues[worker_id].put(trans_data, timeout=1)
- break
- except queue.Full:
- if self.shutdown_flag.value == 1:
- break
- logger.debug("batch queue if full")
-
- def _worker_to_batch_queue(self):
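- # Poll the workers' trans_data_queues round-robin; once batch_size items are collected, collate them and push the full batch downstream.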
- cnt = -1
- trans_items = []
- while True:
- if self.shutdown_flag.value == 1:
- break
- cnt += 1
- queue_id = cnt % self.num_workers
- try:
- trans_item = self.trans_data_queues[queue_id].get(
- timeout=GLOBAL_TIMEOUT
- )
- except queue.Empty:
- continue
- trans_items.append(trans_item)
- if len(trans_items) == self.sampler.batch_size:
- batch_data = self.collator.apply(trans_items)
- while True:
- try:
- self.batch_queue.put(batch_data, timeout=1)
- break
- except queue.Full:
- if self.shutdown_flag.value == 1:
- break
- logger.debug("batch queue is full")
- trans_items = []
-
- def _check_workers(self):
- if not self.collect_worker.is_alive():
- exitcode = self.collect_worker.exitcode
- if exitcode != 0:
- raise RuntimeError("collator worker died. {}".format(exitcode))
-
- for worker_id, worker in enumerate(self.transform_workers):
- if not worker.is_alive():
- exitcode = worker.exitcode
- if exitcode != 0:
- raise RuntimeError(
- "worker: {} died. {}".format(worker_id, exitcode)
- )
-
- def _get_next_batch(self):
- start_time = time.time()
- while True:
- self._check_workers()
- try:
- return self.batch_queue.get(timeout=1)
- except queue.Empty:
- logger.debug("batch queue empty!")
- waited_time = time.time() - start_time
- if self.timeout > 0 and waited_time > self.timeout:
- self._put_raw_data_queues(self.timeout_event(), 0)
-
- def _shutdown(self):
- with self.shutdown_flag.get_lock():
- self.shutdown_flag.value = 1
-
- if self.receive_worker.is_alive():
- self.receive_worker.terminate()
- self.receive_worker.join()
-
- if self.collect_worker.is_alive():
- self.collect_worker.terminate()
- self.collect_worker.join()
-
- for worker in self.transform_workers:
- if worker.is_alive():
- worker.terminate()
- worker.join()
-
- for q in self.raw_data_queues:
- q.cancel_join_thread()
- q.close()
-
- for q in self.trans_data_queues:
- q.cancel_join_thread()
- q.close()
-
- self.batch_queue.cancel_join_thread()
- self.batch_queue.close()
-
- def __del__(self):
- if self.__initialized:
- self._shutdown()
-
-
- def _task_feeding_loop(
- indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
- ):
- # Feed the indices into the task queues
- while True:
- if shutdown_flag.value == 1:
- break
- batch_idx = feed_batch_idx.value
- try:
- indices = next(indices_iter)
- except StopIteration:
- break
- if divide:
- # make sure all task_queues are ready for put
- while any([q.full() for q in task_queues]):
- if shutdown_flag.value == 1:
- return
- # divide into small pieces, feed to different workers.
- sub_num = math.ceil(len(indices) / num_workers)
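- # e.g. 10 indices and 4 workers give sub_num = 3: pieces of size 3, 3, 3 and 1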
- for worker_id in range(num_workers):
- sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num]
- task_queues[worker_id].put((batch_idx, sub_indices))
- else:
- # distribute tasks to different workers uniformly.
- target_id = batch_idx % num_workers
- while task_queues[target_id].full():
- if shutdown_flag.value == 1:
- return
- task_queues[target_id].put((batch_idx, indices))
- with feed_batch_idx.get_lock():
- feed_batch_idx.value += 1
-
-
- def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag):
- # Get dataset items and do the transform
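- # Each worker receives a distinct seed (seed + worker_id + 1 from the parent iterator) so random augmentations are decorrelated across processes.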
- random.seed(seed)
- np.random.seed(seed)
- while True:
- if shutdown_flag.value == 1:
- break
- try:
- batch_idx, indices = task_queue.get(timeout=GLOBAL_TIMEOUT)
- except queue.Empty:
- continue
- if len(indices) > 0:
- items = [dataset[idx] for idx in indices]
- trans_items = transform.apply_batch(items)
- else:
- # in case of incomplete last batch
- trans_items = ()
- while True:
- try:
- trans_data_queue.put((batch_idx, trans_items), timeout=1)
- break
- except queue.Full:
- if shutdown_flag.value == 1:
- break
- logger.debug("batch part queue is full!")
-
-
- def _data_gathering_loop(
- trans_data_queues,
- batch_queue,
- collator,
- length,
- num_workers,
- shutdown_flag,
- target_idx,
- ):
- # Gather the pieces produced by each worker into a full batch
- while True:
- if shutdown_flag.value == 1:
- break
-
- target_batch_idx = target_idx.value
-
- if target_batch_idx >= length:
- break
-
- full_trans_items = []
- for worker_id in range(num_workers):
- while True:
- try:
- batch_idx, trans_items = trans_data_queues[worker_id].get(
- timeout=GLOBAL_TIMEOUT
- )
- break
- except queue.Empty:
- if shutdown_flag.value == 1:
- break
- logger.debug(
- "worker:{} data queue get timeout! target batch idx:{}".format(
- worker_id, target_batch_idx
- )
- )
- if batch_idx != target_batch_idx:
- raise RuntimeError(
- "Unexperted batch_idx in data gathering loop. worker_id:{}.".format(
- worker_id
- )
- )
- else:
- full_trans_items.extend(trans_items)
-
- # Merge different parts into a batch.
- full_batch = collator.apply(full_trans_items)
-
- while True:
- try:
- batch_queue.put(full_batch, timeout=1)
- break
- except queue.Full:
- if shutdown_flag.value == 1:
- break
- logger.debug("batch queue is full!")
-
- with target_idx.get_lock():
- target_idx.value += 1
-
- batch_queue.disconnect_client()
-
-
- def _data_selecting_loop(
- trans_data_queues,
- batch_queue,
- collator,
- length,
- num_workers,
- shutdown_flag,
- target_idx,
- ):
- # Make sure batches are yielded in exactly the same order as the generated indices
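- # e.g. with 2 workers, batch 0 is taken from worker 0's queue, batch 1 from worker 1's, batch 2 from worker 0's, and so on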
- while True:
- if shutdown_flag.value == 1:
- break
-
- target_batch_idx = target_idx.value
-
- if target_batch_idx >= length:
- break
-
- target_worker_id = target_batch_idx % num_workers
- while True:
- try:
- batch_idx, trans_items = trans_data_queues[target_worker_id].get(
- timeout=GLOBAL_TIMEOUT
- )
- batch_data = collator.apply(trans_items)
- break
- except queue.Empty:
- if shutdown_flag.value == 1:
- break
- logger.debug(
- "worker:{} data queue get timeout! target batch idx:{}".format(
- target_worker_id, target_batch_idx
- )
- )
-
- if batch_idx != target_batch_idx:
- raise RuntimeError(
- "batch_idx {} mismatch the target_batch_idx {}".format(
- batch_idx, target_batch_idx
- )
- )
-
- while True:
- try:
- batch_queue.put(batch_data, timeout=1)
- break
- except queue.Full:
- if shutdown_flag.value == 1:
- break
- logger.debug("batch queue is full!")
-
- with target_idx.get_lock():
- target_idx.value += 1
-
- batch_queue.disconnect_client()