# -*- coding: utf-8 -*- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import collections import math import multiprocessing import queue import random import time import numpy as np import megengine as mge from ..logger import get_logger from .collator import Collator from .dataset import Dataset from .sampler import Sampler, SequentialSampler from .transform import PseudoTransform, Transform logger = get_logger(__name__) MP_QUEUE_GET_TIMEOUT = 5 class DataLoader: __initialized = False def __init__( self, dataset: Dataset, sampler: Sampler = None, transform: Transform = None, collator: Collator = None, num_workers: int = 0, timeout: int = 0, divide: bool = False, ): r"""Provides a convenient way to iterate on a given dataset. `DataLoader` combines a dataset with sampler, transform and collator, make it flexible to get minibatch continually from a dataset. :type dataset: Dataset :param dataset: dataset from which to load the minibatch. :type sampler: Sampler :param sampler: defines the strategy to sample data from the dataset. If specified, :attr:`shuffle` must be ``False``. :type transform: Transform :param transform: defined the transforming strategy for a sampled batch. (default: ``None``) :type collator: Collator :param collator: defined the merging strategy for a transformed batch. (default: ``None``) :type num_workers: int :param num_workers: the number of sub-process to load, transform and collate the batch. ``0`` means using single-process. (default: ``0``) :type timeout: int :param timeout: if positive, means the timeout value(second) for collecting a batch from workers. (default: 0) :type divide: bool :param divide: define the paralleling strategy in multi-processing mode. ``True`` means one batch is divided into :attr:`num_workers` pieces, and the workers will process these pieces parallelly. ``False`` means different sub-process will process different batch. (default: ``False``) """ if num_workers < 0: raise ValueError("num_workers should not be negative") if timeout < 0: raise ValueError("timeout should not be negative") if divide and num_workers <= 1: raise ValueError("divide should not be set to True when num_workers <= 1") self.dataset = dataset self.num_workers = num_workers self.timeout = timeout self.divide = divide self.rng = np.random.RandomState() if sampler is None: self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False) else: self.sampler = sampler if divide: if self.sampler.batch_size <= self.num_workers: raise ValueError( "batch size must not smaller than num_workers in divide mode." ) elif self.sampler.batch_size % self.num_workers: logger.warning( "batch size is not divisible by num_workers, may lose performance in divide mode." ) if transform is None: self.transform = PseudoTransform() else: self.transform = transform if collator is None: self.collator = Collator() else: self.collator = collator self.__initialized = True def __iter__(self): if self.num_workers == 0: return _SerialDataLoaderIter(self) else: return _ParallelDataLoaderIter(self) def __len__(self): return len(self.sampler) class _BaseDataLoaderIter: def __init__(self, loader): self.dataset = loader.dataset self.sampler = loader.sampler self.seed = loader.rng.randint(1e9) self.transform = loader.transform self.collator = loader.collator self.num_workers = loader.num_workers self.timeout = loader.timeout self.divide = loader.divide self.num_processed = 0 def _get_next_batch(self): raise NotImplementedError def __len__(self): return len(self.sampler) def __iter__(self): return self def __next__(self): if self.num_processed >= len(self): raise StopIteration minibatch = self._get_next_batch() self.num_processed += 1 return minibatch class _SerialDataLoaderIter(_BaseDataLoaderIter): def __init__(self, loader): super(_SerialDataLoaderIter, self).__init__(loader) self.indices_iter = iter(self.sampler) def _get_next_batch(self): indices = next(self.indices_iter) items = [self.dataset[idx] for idx in indices] trans_items = self.transform.apply_batch(items) return self.collator.apply(trans_items) class _ParallelDataLoaderIter(_BaseDataLoaderIter): __initialized = False def __init__(self, loader): super(_ParallelDataLoaderIter, self).__init__(loader) # if any worker died, all workers will be shutdown. self.strict = True # TODO: put `strict` into DataLoader args or not? self.task_queues = [ multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers) ] self.feed_batch_idx = multiprocessing.Value("i", 0) self.target_batch_idx = multiprocessing.Value("i", 0) self.shutdown_flag = multiprocessing.Value("i", 0) self.batch_part_queues = [ multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers) ] # use shared-memory queue implemented by pyarrow plasma store. from ._queue import PlasmaShmQueue self.batch_queue = PlasmaShmQueue(maxsize=2) self.task_feeding_worker = multiprocessing.Process( target=self._task_feeding_loop, args=(iter(self.sampler), self.divide), daemon=True, ) self.task_feeding_worker.start() self.workers = [] for worker_id in range(self.num_workers): worker = multiprocessing.Process( target=self._worker_loop, args=( self.task_queues[worker_id], self.batch_part_queues[worker_id], self.transform, self.collator, self.seed + worker_id + 1, ), daemon=True, ) worker.start() self.workers.append(worker) if self.divide: self.data_collecting_worker = multiprocessing.Process( target=self._data_gathering_loop, args=(self.batch_part_queues, self.batch_queue,), daemon=True, ) else: self.data_collecting_worker = multiprocessing.Process( target=self._data_selecting_loop, args=(self.batch_part_queues, self.batch_queue,), daemon=True, ) self.data_collecting_worker.start() self.__initialized = True def _task_feeding_loop(self, indices_iter, divide): while True: if self.shutdown_flag.value == 1: break batch_idx = self.feed_batch_idx.value try: indices = next(indices_iter) except StopIteration: break if divide: # make sure all task_queues is ready for put while any([q.full() for q in self.task_queues]): if self.shutdown_flag.value == 1: return # divide into small pieces, feed to different workers. sub_num = math.ceil(len(indices) / self.num_workers) for worker_id in range(self.num_workers): sub_indices = indices[ worker_id * sub_num : (worker_id + 1) * sub_num ] self.task_queues[worker_id].put((batch_idx, sub_indices)) else: # distribute tasks to different workers uniformly. target_id = batch_idx % self.num_workers while self.task_queues[target_id].full(): if self.shutdown_flag.value == 1: return self.task_queues[target_id].put((batch_idx, indices)) with self.feed_batch_idx.get_lock(): self.feed_batch_idx.value += 1 def _worker_loop(self, task_queue, data_queue, transform, collator, seed): random.seed(seed) np.random.seed(seed) while True: if self.shutdown_flag.value == 1: break try: batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT) except queue.Empty: continue if len(indices) > 0: items = [self.dataset[idx] for idx in indices] trans_items = transform.apply_batch(items) batch_data = collator.apply(trans_items) else: # in case of incomplete last batch batch_data = () while True: try: data_queue.put((np.array([batch_idx]), batch_data), timeout=1) break except queue.Full: if self.shutdown_flag.value == 1: break logger.debug("batch part queue is full!") continue def _data_gathering_loop(self, batch_part_queues, batch_queue): r"""Gathering the small pieces of batch data into full batch data.""" gathered_data = collections.defaultdict(dict) while True: if self.shutdown_flag.value == 1: break target_batch_idx = self.target_batch_idx.value if target_batch_idx >= len(self): break for worker_id in range(self.num_workers): if worker_id in gathered_data[target_batch_idx]: continue while True: try: (batch_idx,), batch_part = batch_part_queues[worker_id].get( timeout=MP_QUEUE_GET_TIMEOUT ) break except queue.Empty: if self.shutdown_flag.value == 1: break logger.debug( "worker:{} data queue get timeout! target batch idx:{}".format( worker_id, target_batch_idx ) ) if batch_idx < target_batch_idx: raise RuntimeError( "Unexperted batch_idx in data gathering loop. worker_id:{}.".format( worker_id ) ) else: gathered_data[batch_idx][worker_id] = batch_part if len(gathered_data[target_batch_idx]) < self.num_workers: length = len(gathered_data[target_batch_idx]) if self.strict: raise RuntimeError("Parts missing in data gathering loop.") logger.warning( "target_batch_idx:{}, {} part(s) missing.".format( target_batch_idx, self.num_workers - length ) ) del gathered_data[target_batch_idx] with self.target_batch_idx.get_lock(): self.target_batch_idx.value += 1 continue # Merge different parts. full_batch = [[] for _ in range(len(gathered_data[target_batch_idx][0]))] for idx in range(self.num_workers): for i, field in enumerate(gathered_data[target_batch_idx][idx]): full_batch[i].append(field) full_batch = tuple([np.concatenate(field, axis=0) for field in full_batch]) while True: try: batch_queue.put(full_batch, timeout=1) break except queue.Full: if self.shutdown_flag.value == 1: break logger.debug("batch queue is full!") continue del gathered_data[target_batch_idx] with self.target_batch_idx.get_lock(): self.target_batch_idx.value += 1 batch_queue.disconnect_client() def _data_selecting_loop(self, batch_part_queues, batch_queue): r"""Make sure that batch is generated exactly with the same order as generated indices.""" buffer_batches = {} while True: if self.shutdown_flag.value == 1: break target_batch_idx = self.target_batch_idx.value if target_batch_idx >= len(self): break if target_batch_idx in buffer_batches: while True: try: batch_queue.put( buffer_batches[target_batch_idx], timeout=1, ) break except queue.Full: if self.shutdown_flag.value == 1: break logger.debug("batch queue is full!") with self.target_batch_idx.get_lock(): self.target_batch_idx.value += 1 del buffer_batches[target_batch_idx] continue target_worker_id = target_batch_idx % self.num_workers while True: try: (batch_idx,), batch_data = batch_part_queues[target_worker_id].get( timeout=MP_QUEUE_GET_TIMEOUT ) break except queue.Empty: if self.shutdown_flag.value == 1: break logger.debug( "worker:{} data queue get timeout! target batch idx:{}".format( target_worker_id, target_batch_idx ) ) if batch_idx < target_batch_idx: raise RuntimeError("batch_idx smaller than target_batch_idx") elif batch_idx > target_batch_idx: if self.strict: raise RuntimeError("batch_idx larger than target_batch_idx") logger.warning( "missing target batch idx:{}, batch idx:{}".format( target_batch_idx, batch_idx ) ) buffer_batches[batch_idx] = batch_data else: try: batch_queue.put(batch_data, timeout=1) except queue.Full: buffer_batches[batch_idx] = batch_data continue with self.target_batch_idx.get_lock(): self.target_batch_idx.value += 1 batch_queue.disconnect_client() def _check_workers(self): """Check the status of each worker and restart if necessary.""" if not self.data_collecting_worker.is_alive(): exitcode = self.task_feeding_worker.exitcode if exitcode != 0: raise RuntimeError("data collecting worker died. {}".format(exitcode)) if self.strict: if not self.task_feeding_worker.is_alive(): exitcode = self.task_feeding_worker.exitcode if exitcode != 0: raise RuntimeError("task feeding worker died. {}".format(exitcode)) for worker_id, worker in enumerate(self.workers): if not worker.is_alive(): exitcode = worker.exitcode if exitcode != 0: raise RuntimeError( "worker:{} died. {}".format(worker_id, exitcode) ) else: if not self.task_feeding_worker.is_alive(): exitcode = self.task_feeding_worker.exitcode if exitcode != 0: logger.error( "task feeding worker died {}. Restarting".format(exitcode) ) self.task_feeding_worker.join() self.task_feeding_worker = multiprocessing.Process( target=self._task_feeding_loop, args=(iter(self.sampler), self.divide), daemon=True, ) self.task_feeding_worker.start() failed_num = 0 for worker_id in range(self.num_workers): if self.workers[worker_id].is_alive(): continue exitcode = worker.exitcode if exitcode == 0: continue logger.error("worker {} died. Restarting".format(worker_id)) failed_num += 1 self.workers[worker_id].join() worker = multiprocessing.Process( target=self._worker_loop, args=( self.task_queues[worker_id], self.batch_part_queues[worker_id], self.transform, self.collator, self.seed + worker_id + 1, ), daemon=True, ) worker.start() self.workers[worker_id] = worker if failed_num > 0: logger.error("{} worker had exited".format(failed_num)) else: logger.debug("all workers are alive.") def _try_get_next_batch(self): start_time = time.time() while True: self._check_workers() try: return self.batch_queue.get(timeout=1) except queue.Empty: logger.debug("batch queue empty!") waited_time = time.time() - start_time if self.timeout > 0: if waited_time > self.timeout: raise RuntimeError("get_next_batch timeout!") def _get_next_batch(self): batch_data = self._try_get_next_batch() return batch_data def _shutdown(self): with self.shutdown_flag.get_lock(): self.shutdown_flag.value = 1 if self.task_feeding_worker.is_alive(): self.task_feeding_worker.terminate() self.task_feeding_worker.join() if self.data_collecting_worker.is_alive(): self.data_collecting_worker.terminate() self.data_collecting_worker.join() for worker in self.workers: if worker.is_alive(): worker.terminate() worker.join() for q in self.batch_part_queues: q.cancel_join_thread() q.close() for q in self.task_queues: q.cancel_join_thread() q.close() self.batch_queue.cancel_join_thread() self.batch_queue.close() def __del__(self): if self.__initialized: self._shutdown()