From 2c2caf33311fe41a21b236073130e7c26eb42fa5 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 27 Mar 2020 10:44:48 +0800
Subject: [PATCH] refactor(mge/data/dataloader): refactor the implementation of parallel dataloader

GitOrigin-RevId: 0554ee8427c7d892557422c1ee57597b7c88756b
---
 python_module/megengine/data/dataloader.py | 504 +++++++++++++----------
 1 file changed, 229 insertions(+), 275 deletions(-)

diff --git a/python_module/megengine/data/dataloader.py b/python_module/megengine/data/dataloader.py
index 388a8f0e..1fd3482d 100644
--- a/python_module/megengine/data/dataloader.py
+++ b/python_module/megengine/data/dataloader.py
@@ -15,9 +15,8 @@ import time
 
 import numpy as np
 
-import megengine as mge
-
 from ..logger import get_logger
+from ..random.rng import _random_seed_generator
 from .collator import Collator
 from .dataset import Dataset
 from .sampler import Sampler, SequentialSampler
@@ -87,8 +86,6 @@ class DataLoader:
 
         self.divide = divide
 
-        self.rng = np.random.RandomState()
-
         if sampler is None:
             self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False)
         else:
@@ -130,7 +127,7 @@ class _BaseDataLoaderIter:
     def __init__(self, loader):
         self.dataset = loader.dataset
         self.sampler = loader.sampler
-        self.seed = loader.rng.randint(1e9)
+        self.seed = _random_seed_generator().__next__()
         self.transform = loader.transform
         self.collator = loader.collator
         self.num_workers = loader.num_workers
@@ -173,10 +170,6 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
     def __init__(self, loader):
         super(_ParallelDataLoaderIter, self).__init__(loader)
 
-        # if any worker died, all workers will be shutdown.
-        self.strict = True
-        # TODO: put `strict` into DataLoader args or not?
-
         self.task_queues = [
             multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
         ]
@@ -185,7 +178,7 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
         self.feed_batch_idx = multiprocessing.Value("i", 0)
         self.target_batch_idx = multiprocessing.Value("i", 0)
         self.shutdown_flag = multiprocessing.Value("i", 0)
 
-        self.batch_part_queues = [
+        self.trans_data_queues = [
             multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
         ]
@@ -195,8 +188,15 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
         self.batch_queue = PlasmaShmQueue(maxsize=2)
 
         self.task_feeding_worker = multiprocessing.Process(
-            target=self._task_feeding_loop,
-            args=(iter(self.sampler), self.divide),
+            target=_task_feeding_loop,
+            args=(
+                iter(self.sampler),
+                self.task_queues,
+                self.num_workers,
+                self.divide,
+                self.shutdown_flag,
+                self.feed_batch_idx,
+            ),
             daemon=True,
         )
         self.task_feeding_worker.start()
@@ -204,13 +204,14 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
         self.workers = []
         for worker_id in range(self.num_workers):
             worker = multiprocessing.Process(
-                target=self._worker_loop,
+                target=_worker_loop,
                 args=(
+                    self.dataset,
                     self.task_queues[worker_id],
-                    self.batch_part_queues[worker_id],
+                    self.trans_data_queues[worker_id],
                     self.transform,
-                    self.collator,
                     self.seed + worker_id + 1,
+                    self.shutdown_flag,
                 ),
                 daemon=True,
             )
@@ -219,282 +220,55 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
 
         if self.divide:
             self.data_collecting_worker = multiprocessing.Process(
-                target=self._data_gathering_loop,
-                args=(self.batch_part_queues, self.batch_queue,),
+                target=_data_gathering_loop,
+                args=(
+                    self.trans_data_queues,
+                    self.batch_queue,
+                    self.collator,
+                    len(self),
+                    self.num_workers,
+                    self.shutdown_flag,
+                    self.target_batch_idx,
+                ),
                 daemon=True,
             )
         else:
             self.data_collecting_worker = multiprocessing.Process(
-                target=self._data_selecting_loop,
-                args=(self.batch_part_queues, self.batch_queue,),
+                target=_data_selecting_loop,
+                args=(
+                    self.trans_data_queues,
+                    self.batch_queue,
+                    self.collator,
+                    len(self),
+                    self.num_workers,
+                    self.shutdown_flag,
+                    self.target_batch_idx,
+                ),
                 daemon=True,
             )
         self.data_collecting_worker.start()
 
         self.__initialized = True
 
-    def _task_feeding_loop(self, indices_iter, divide):
-        while True:
-            if self.shutdown_flag.value == 1:
-                break
-            batch_idx = self.feed_batch_idx.value
-            try:
-                indices = next(indices_iter)
-            except StopIteration:
-                break
-            if divide:
-                # make sure all task_queues is ready for put
-                while any([q.full() for q in self.task_queues]):
-                    if self.shutdown_flag.value == 1:
-                        return
-                # divide into small pieces, feed to different workers.
-                sub_num = math.ceil(len(indices) / self.num_workers)
-                for worker_id in range(self.num_workers):
-                    sub_indices = indices[
-                        worker_id * sub_num : (worker_id + 1) * sub_num
-                    ]
-                    self.task_queues[worker_id].put((batch_idx, sub_indices))
-            else:
-                # distribute tasks to different workers uniformly.
-                target_id = batch_idx % self.num_workers
-                while self.task_queues[target_id].full():
-                    if self.shutdown_flag.value == 1:
-                        return
-                self.task_queues[target_id].put((batch_idx, indices))
-            with self.feed_batch_idx.get_lock():
-                self.feed_batch_idx.value += 1
-
-    def _worker_loop(self, task_queue, data_queue, transform, collator, seed):
-        random.seed(seed)
-        np.random.seed(seed)
-        while True:
-            if self.shutdown_flag.value == 1:
-                break
-            try:
-                batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT)
-            except queue.Empty:
-                continue
-            if len(indices) > 0:
-                items = [self.dataset[idx] for idx in indices]
-                trans_items = transform.apply_batch(items)
-                batch_data = collator.apply(trans_items)
-            else:
-                # in case of incomplete last batch
-                batch_data = ()
-            while True:
-                try:
-                    data_queue.put((np.array([batch_idx]), batch_data), timeout=1)
-                    break
-                except queue.Full:
-                    if self.shutdown_flag.value == 1:
-                        break
-                    logger.debug("batch part queue is full!")
-                    continue
-
-    def _data_gathering_loop(self, batch_part_queues, batch_queue):
-        r"""Gathering the small pieces of batch data into full batch data."""
-        gathered_data = collections.defaultdict(dict)
-        while True:
-            if self.shutdown_flag.value == 1:
-                break
-
-            target_batch_idx = self.target_batch_idx.value
-
-            if target_batch_idx >= len(self):
-                break
-
-            for worker_id in range(self.num_workers):
-                if worker_id in gathered_data[target_batch_idx]:
-                    continue
-                while True:
-                    try:
-                        (batch_idx,), batch_part = batch_part_queues[worker_id].get(
-                            timeout=MP_QUEUE_GET_TIMEOUT
-                        )
-                        break
-                    except queue.Empty:
-                        if self.shutdown_flag.value == 1:
-                            break
-                        logger.debug(
-                            "worker:{} data queue get timeout! target batch idx:{}".format(
-                                worker_id, target_batch_idx
-                            )
-                        )
-                if batch_idx < target_batch_idx:
-                    raise RuntimeError(
-                        "Unexperted batch_idx in data gathering loop. worker_id:{}.".format(
-                            worker_id
-                        )
-                    )
-                else:
-                    gathered_data[batch_idx][worker_id] = batch_part
-
-            if len(gathered_data[target_batch_idx]) < self.num_workers:
-                length = len(gathered_data[target_batch_idx])
-                if self.strict:
-                    raise RuntimeError("Parts missing in data gathering loop.")
-                logger.warning(
-                    "target_batch_idx:{}, {} part(s) missing.".format(
-                        target_batch_idx, self.num_workers - length
-                    )
-                )
-                del gathered_data[target_batch_idx]
-                with self.target_batch_idx.get_lock():
-                    self.target_batch_idx.value += 1
-                continue
-
-            # Merge different parts.
-            full_batch = [[] for _ in range(len(gathered_data[target_batch_idx][0]))]
-            for idx in range(self.num_workers):
-                for i, field in enumerate(gathered_data[target_batch_idx][idx]):
-                    full_batch[i].append(field)
-            full_batch = tuple([np.concatenate(field, axis=0) for field in full_batch])
-
-            while True:
-                try:
-                    batch_queue.put(full_batch, timeout=1)
-                    break
-                except queue.Full:
-                    if self.shutdown_flag.value == 1:
-                        break
-                    logger.debug("batch queue is full!")
-                    continue
-
-            del gathered_data[target_batch_idx]
-
-            with self.target_batch_idx.get_lock():
-                self.target_batch_idx.value += 1
-
-        batch_queue.disconnect_client()
-
-    def _data_selecting_loop(self, batch_part_queues, batch_queue):
-        r"""Make sure that batch is generated exactly with the same order as generated indices."""
-        buffer_batches = {}
-        while True:
-            if self.shutdown_flag.value == 1:
-                break
-
-            target_batch_idx = self.target_batch_idx.value
-
-            if target_batch_idx >= len(self):
-                break
-
-            if target_batch_idx in buffer_batches:
-                while True:
-                    try:
-                        batch_queue.put(
-                            buffer_batches[target_batch_idx], timeout=1,
-                        )
-                        break
-                    except queue.Full:
-                        if self.shutdown_flag.value == 1:
-                            break
-                        logger.debug("batch queue is full!")
-                with self.target_batch_idx.get_lock():
-                    self.target_batch_idx.value += 1
-                del buffer_batches[target_batch_idx]
-                continue
-
-            target_worker_id = target_batch_idx % self.num_workers
-            while True:
-                try:
-                    (batch_idx,), batch_data = batch_part_queues[target_worker_id].get(
-                        timeout=MP_QUEUE_GET_TIMEOUT
-                    )
-                    break
-                except queue.Empty:
-                    if self.shutdown_flag.value == 1:
-                        break
-                    logger.debug(
-                        "worker:{} data queue get timeout! target batch idx:{}".format(
-                            target_worker_id, target_batch_idx
-                        )
-                    )
-
-            if batch_idx < target_batch_idx:
-                raise RuntimeError("batch_idx smaller than target_batch_idx")
-            elif batch_idx > target_batch_idx:
-                if self.strict:
-                    raise RuntimeError("batch_idx larger than target_batch_idx")
-                logger.warning(
-                    "missing target batch idx:{}, batch idx:{}".format(
-                        target_batch_idx, batch_idx
-                    )
-                )
-                buffer_batches[batch_idx] = batch_data
-            else:
-                try:
-                    batch_queue.put(batch_data, timeout=1)
-                except queue.Full:
-                    buffer_batches[batch_idx] = batch_data
-                    continue
-
-            with self.target_batch_idx.get_lock():
-                self.target_batch_idx.value += 1
-
-        batch_queue.disconnect_client()
-
     def _check_workers(self):
-        """Check the status of each worker and restart if necessary."""
+        # Check the status of each worker.
         if not self.data_collecting_worker.is_alive():
-            exitcode = self.task_feeding_worker.exitcode
+            exitcode = self.data_collecting_worker.exitcode
             if exitcode != 0:
                 raise RuntimeError("data collecting worker died. {}".format(exitcode))
 
-        if self.strict:
-            if not self.task_feeding_worker.is_alive():
-                exitcode = self.task_feeding_worker.exitcode
-                if exitcode != 0:
-                    raise RuntimeError("task feeding worker died. {}".format(exitcode))
-            for worker_id, worker in enumerate(self.workers):
-                if not worker.is_alive():
-                    exitcode = worker.exitcode
-                    if exitcode != 0:
-                        raise RuntimeError(
-                            "worker:{} died. {}".format(worker_id, exitcode)
-                        )
-        else:
-            if not self.task_feeding_worker.is_alive():
-                exitcode = self.task_feeding_worker.exitcode
-                if exitcode != 0:
-                    logger.error(
-                        "task feeding worker died {}. Restarting".format(exitcode)
Restarting".format(exitcode) - ) - self.task_feeding_worker.join() - self.task_feeding_worker = multiprocessing.Process( - target=self._task_feeding_loop, - args=(iter(self.sampler), self.divide), - daemon=True, - ) - self.task_feeding_worker.start() - failed_num = 0 - for worker_id in range(self.num_workers): - if self.workers[worker_id].is_alive(): - continue + if not self.task_feeding_worker.is_alive(): + exitcode = self.task_feeding_worker.exitcode + if exitcode != 0: + raise RuntimeError("task feeding worker died. {}".format(exitcode)) + + for worker_id, worker in enumerate(self.workers): + if not worker.is_alive(): exitcode = worker.exitcode - if exitcode == 0: - continue - logger.error("worker {} died. Restarting".format(worker_id)) - failed_num += 1 - self.workers[worker_id].join() - worker = multiprocessing.Process( - target=self._worker_loop, - args=( - self.task_queues[worker_id], - self.batch_part_queues[worker_id], - self.transform, - self.collator, - self.seed + worker_id + 1, - ), - daemon=True, - ) - worker.start() - self.workers[worker_id] = worker + if exitcode != 0: + raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode)) - if failed_num > 0: - logger.error("{} worker had exited".format(failed_num)) - else: - logger.debug("all workers are alive.") + logger.debug("all workers are alive.") def _try_get_next_batch(self): start_time = time.time() @@ -530,7 +304,7 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter): worker.terminate() worker.join() - for q in self.batch_part_queues: + for q in self.trans_data_queues: q.cancel_join_thread() q.close() @@ -544,3 +318,183 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter): def __del__(self): if self.__initialized: self._shutdown() + + +def _task_feeding_loop( + indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx +): + # Feed the indices into the task queues + while True: + if shutdown_flag.value == 1: + break + batch_idx = feed_batch_idx.value + try: + indices = next(indices_iter) + except StopIteration: + break + if divide: + # make sure all task_queues is ready for put + while any([q.full() for q in task_queues]): + if shutdown_flag.value == 1: + return + # divide into small pieces, feed to different workers. + sub_num = math.ceil(len(indices) / num_workers) + for worker_id in range(num_workers): + sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num] + task_queues[worker_id].put((batch_idx, sub_indices)) + else: + # distribute tasks to different workers uniformly. 
+            target_id = batch_idx % num_workers
+            while task_queues[target_id].full():
+                if shutdown_flag.value == 1:
+                    return
+            task_queues[target_id].put((batch_idx, indices))
+        with feed_batch_idx.get_lock():
+            feed_batch_idx.value += 1
+
+
+def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag):
+    # Get dataset items and apply the transform
+    random.seed(seed)
+    np.random.seed(seed)
+    while True:
+        if shutdown_flag.value == 1:
+            break
+        try:
+            batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT)
+        except queue.Empty:
+            continue
+        if len(indices) > 0:
+            items = [dataset[idx] for idx in indices]
+            trans_items = transform.apply_batch(items)
+        else:
+            # in case of incomplete last batch
+            trans_items = ()
+        while True:
+            try:
+                trans_data_queue.put((batch_idx, trans_items), timeout=1)
+                break
+            except queue.Full:
+                if shutdown_flag.value == 1:
+                    break
+                logger.debug("batch part queue is full!")
+
+
+def _data_gathering_loop(
+    trans_data_queues,
+    batch_queue,
+    collator,
+    length,
+    num_workers,
+    shutdown_flag,
+    target_idx,
+):
+    # Gather the small pieces of batch data into full batch data
+    while True:
+        if shutdown_flag.value == 1:
+            break
+
+        target_batch_idx = target_idx.value
+
+        if target_batch_idx >= length:
+            break
+
+        full_trans_items = []
+        for worker_id in range(num_workers):
+            while True:
+                try:
+                    batch_idx, trans_items = trans_data_queues[worker_id].get(
+                        timeout=MP_QUEUE_GET_TIMEOUT
+                    )
+                    break
+                except queue.Empty:
+                    if shutdown_flag.value == 1:
+                        break
+                    logger.debug(
+                        "worker:{} data queue get timeout! target batch idx:{}".format(
+                            worker_id, target_batch_idx
+                        )
+                    )
+            if batch_idx != target_batch_idx:
+                raise RuntimeError(
+                    "Unexpected batch_idx in data gathering loop. worker_id:{}.".format(
+                        worker_id
+                    )
+                )
+            else:
+                full_trans_items.extend(trans_items)
+
+        # Merge different parts into a batch.
+        full_batch = collator.apply(full_trans_items)
+
+        while True:
+            try:
+                batch_queue.put(full_batch, timeout=1)
+                break
+            except queue.Full:
+                if shutdown_flag.value == 1:
+                    break
+                logger.debug("batch queue is full!")
+
+        with target_idx.get_lock():
+            target_idx.value += 1
+
+    batch_queue.disconnect_client()
+
+
+def _data_selecting_loop(
+    trans_data_queues,
+    batch_queue,
+    collator,
+    length,
+    num_workers,
+    shutdown_flag,
+    target_idx,
+):
+    # Make sure that batches are generated in exactly the same order as the indices
+    while True:
+        if shutdown_flag.value == 1:
+            break
+
+        target_batch_idx = target_idx.value
+
+        if target_batch_idx >= length:
+            break
+
+        target_worker_id = target_batch_idx % num_workers
+        while True:
+            try:
+                batch_idx, trans_items = trans_data_queues[target_worker_id].get(
+                    timeout=MP_QUEUE_GET_TIMEOUT
+                )
+                batch_data = collator.apply(trans_items)
+                break
+            except queue.Empty:
+                if shutdown_flag.value == 1:
+                    break
+                logger.debug(
+                    "worker:{} data queue get timeout! target batch idx:{}".format(
+                        target_worker_id, target_batch_idx
+                    )
+                )
+
+        if batch_idx != target_batch_idx:
+            raise RuntimeError(
+                "batch_idx {} does not match target_batch_idx {}".format(
+                    batch_idx, target_batch_idx
+                )
+            )
+
+        while True:
+            try:
+                batch_queue.put(batch_data, timeout=1)
+                break
+            except queue.Full:
+                if shutdown_flag.value == 1:
+                    break
+                logger.debug("batch queue is full!")
+
+        with target_idx.get_lock():
+            target_idx.value += 1
+
+    batch_queue.disconnect_client()
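
The point of the refactor above is that every loop now lives at module level and receives its queues and the shared shutdown_flag as explicit arguments, so a multiprocessing.Process target no longer closes over the _ParallelDataLoaderIter instance (and everything it holds). A minimal, runnable sketch of that pattern follows; the names (_square_worker, task_q, result_q) are hypothetical stand-ins for illustration, not MegEngine APIs:

    import multiprocessing
    import queue

    # Module-level worker in the style of _worker_loop: picklable under the
    # "spawn" start method, and every shared object arrives as an explicit
    # argument instead of being read from `self`.
    def _square_worker(task_q, result_q, shutdown_flag):
        while True:
            if shutdown_flag.value == 1:
                break
            try:
                batch_idx, indices = task_q.get(timeout=1)
            except queue.Empty:
                continue
            # Stand-in "transform": square each index.
            result_q.put((batch_idx, [i * i for i in indices]))

    if __name__ == "__main__":
        task_q = multiprocessing.Queue(maxsize=2)
        result_q = multiprocessing.Queue()
        shutdown_flag = multiprocessing.Value("i", 0)

        worker = multiprocessing.Process(
            target=_square_worker,
            args=(task_q, result_q, shutdown_flag),
            daemon=True,
        )
        worker.start()

        for batch_idx in range(3):
            task_q.put((batch_idx, list(range(4))))
        for _ in range(3):
            print(result_q.get())

        shutdown_flag.value = 1
        worker.join()

With a single worker the batches come back in index order; with several workers an ordering stage is still required, which is what _data_selecting_loop (one whole batch per worker) and _data_gathering_loop (one batch split across all workers, recombined by the collator) provide in the patch, driven by the shared target_batch_idx counter.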