Browse Source

refactor(mge/data/dataloader): refactor the implementation of parallel dataloader

GitOrigin-RevId: 0554ee8427
tags/v0.3.2
Megvii Engine Team 5 years ago
parent
commit
2c2caf3331
1 changed files with 228 additions and 274 deletions
  1. +228
    -274
      python_module/megengine/data/dataloader.py

+ 228
- 274
python_module/megengine/data/dataloader.py View File

@@ -15,9 +15,8 @@ import time


import numpy as np import numpy as np


import megengine as mge

from ..logger import get_logger from ..logger import get_logger
from ..random.rng import _random_seed_generator
from .collator import Collator from .collator import Collator
from .dataset import Dataset from .dataset import Dataset
from .sampler import Sampler, SequentialSampler from .sampler import Sampler, SequentialSampler
@@ -87,8 +86,6 @@ class DataLoader:


self.divide = divide self.divide = divide


self.rng = np.random.RandomState()

if sampler is None: if sampler is None:
self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False) self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False)
else: else:
@@ -130,7 +127,7 @@ class _BaseDataLoaderIter:
def __init__(self, loader): def __init__(self, loader):
self.dataset = loader.dataset self.dataset = loader.dataset
self.sampler = loader.sampler self.sampler = loader.sampler
self.seed = loader.rng.randint(1e9)
self.seed = _random_seed_generator().__next__()
self.transform = loader.transform self.transform = loader.transform
self.collator = loader.collator self.collator = loader.collator
self.num_workers = loader.num_workers self.num_workers = loader.num_workers
@@ -173,10 +170,6 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader): def __init__(self, loader):
super(_ParallelDataLoaderIter, self).__init__(loader) super(_ParallelDataLoaderIter, self).__init__(loader)


# if any worker died, all workers will be shutdown.
self.strict = True
# TODO: put `strict` into DataLoader args or not?

self.task_queues = [ self.task_queues = [
multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers) multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
] ]
@@ -185,7 +178,7 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
self.target_batch_idx = multiprocessing.Value("i", 0) self.target_batch_idx = multiprocessing.Value("i", 0)
self.shutdown_flag = multiprocessing.Value("i", 0) self.shutdown_flag = multiprocessing.Value("i", 0)


self.batch_part_queues = [
self.trans_data_queues = [
multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers) multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
] ]


@@ -195,8 +188,15 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
self.batch_queue = PlasmaShmQueue(maxsize=2) self.batch_queue = PlasmaShmQueue(maxsize=2)


self.task_feeding_worker = multiprocessing.Process( self.task_feeding_worker = multiprocessing.Process(
target=self._task_feeding_loop,
args=(iter(self.sampler), self.divide),
target=_task_feeding_loop,
args=(
iter(self.sampler),
self.task_queues,
self.num_workers,
self.divide,
self.shutdown_flag,
self.feed_batch_idx,
),
daemon=True, daemon=True,
) )
self.task_feeding_worker.start() self.task_feeding_worker.start()
@@ -204,13 +204,14 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
self.workers = [] self.workers = []
for worker_id in range(self.num_workers): for worker_id in range(self.num_workers):
worker = multiprocessing.Process( worker = multiprocessing.Process(
target=self._worker_loop,
target=_worker_loop,
args=( args=(
self.dataset,
self.task_queues[worker_id], self.task_queues[worker_id],
self.batch_part_queues[worker_id],
self.trans_data_queues[worker_id],
self.transform, self.transform,
self.collator,
self.seed + worker_id + 1, self.seed + worker_id + 1,
self.shutdown_flag,
), ),
daemon=True, daemon=True,
) )
@@ -219,282 +220,55 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):


if self.divide: if self.divide:
self.data_collecting_worker = multiprocessing.Process( self.data_collecting_worker = multiprocessing.Process(
target=self._data_gathering_loop,
args=(self.batch_part_queues, self.batch_queue,),
target=_data_gathering_loop,
args=(
self.trans_data_queues,
self.batch_queue,
self.collator,
len(self),
self.num_workers,
self.shutdown_flag,
self.target_batch_idx,
),
daemon=True, daemon=True,
) )
else: else:
self.data_collecting_worker = multiprocessing.Process( self.data_collecting_worker = multiprocessing.Process(
target=self._data_selecting_loop,
args=(self.batch_part_queues, self.batch_queue,),
target=_data_selecting_loop,
args=(
self.trans_data_queues,
self.batch_queue,
self.collator,
len(self),
self.num_workers,
self.shutdown_flag,
self.target_batch_idx,
),
daemon=True, daemon=True,
) )
self.data_collecting_worker.start() self.data_collecting_worker.start()


self.__initialized = True self.__initialized = True


def _task_feeding_loop(self, indices_iter, divide):
while True:
if self.shutdown_flag.value == 1:
break
batch_idx = self.feed_batch_idx.value
try:
indices = next(indices_iter)
except StopIteration:
break
if divide:
# make sure all task_queues is ready for put
while any([q.full() for q in self.task_queues]):
if self.shutdown_flag.value == 1:
return
# divide into small pieces, feed to different workers.
sub_num = math.ceil(len(indices) / self.num_workers)
for worker_id in range(self.num_workers):
sub_indices = indices[
worker_id * sub_num : (worker_id + 1) * sub_num
]
self.task_queues[worker_id].put((batch_idx, sub_indices))
else:
# distribute tasks to different workers uniformly.
target_id = batch_idx % self.num_workers
while self.task_queues[target_id].full():
if self.shutdown_flag.value == 1:
return
self.task_queues[target_id].put((batch_idx, indices))
with self.feed_batch_idx.get_lock():
self.feed_batch_idx.value += 1

def _worker_loop(self, task_queue, data_queue, transform, collator, seed):
random.seed(seed)
np.random.seed(seed)
while True:
if self.shutdown_flag.value == 1:
break
try:
batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT)
except queue.Empty:
continue
if len(indices) > 0:
items = [self.dataset[idx] for idx in indices]
trans_items = transform.apply_batch(items)
batch_data = collator.apply(trans_items)
else:
# in case of incomplete last batch
batch_data = ()
while True:
try:
data_queue.put((np.array([batch_idx]), batch_data), timeout=1)
break
except queue.Full:
if self.shutdown_flag.value == 1:
break
logger.debug("batch part queue is full!")
continue

def _data_gathering_loop(self, batch_part_queues, batch_queue):
r"""Gathering the small pieces of batch data into full batch data."""
gathered_data = collections.defaultdict(dict)
while True:
if self.shutdown_flag.value == 1:
break

target_batch_idx = self.target_batch_idx.value

if target_batch_idx >= len(self):
break

for worker_id in range(self.num_workers):
if worker_id in gathered_data[target_batch_idx]:
continue
while True:
try:
(batch_idx,), batch_part = batch_part_queues[worker_id].get(
timeout=MP_QUEUE_GET_TIMEOUT
)
break
except queue.Empty:
if self.shutdown_flag.value == 1:
break
logger.debug(
"worker:{} data queue get timeout! target batch idx:{}".format(
worker_id, target_batch_idx
)
)
if batch_idx < target_batch_idx:
raise RuntimeError(
"Unexperted batch_idx in data gathering loop. worker_id:{}.".format(
worker_id
)
)
else:
gathered_data[batch_idx][worker_id] = batch_part

if len(gathered_data[target_batch_idx]) < self.num_workers:
length = len(gathered_data[target_batch_idx])
if self.strict:
raise RuntimeError("Parts missing in data gathering loop.")
logger.warning(
"target_batch_idx:{}, {} part(s) missing.".format(
target_batch_idx, self.num_workers - length
)
)
del gathered_data[target_batch_idx]
with self.target_batch_idx.get_lock():
self.target_batch_idx.value += 1
continue

# Merge different parts.
full_batch = [[] for _ in range(len(gathered_data[target_batch_idx][0]))]
for idx in range(self.num_workers):
for i, field in enumerate(gathered_data[target_batch_idx][idx]):
full_batch[i].append(field)
full_batch = tuple([np.concatenate(field, axis=0) for field in full_batch])

while True:
try:
batch_queue.put(full_batch, timeout=1)
break
except queue.Full:
if self.shutdown_flag.value == 1:
break
logger.debug("batch queue is full!")
continue

del gathered_data[target_batch_idx]

with self.target_batch_idx.get_lock():
self.target_batch_idx.value += 1

batch_queue.disconnect_client()

def _data_selecting_loop(self, batch_part_queues, batch_queue):
r"""Make sure that batch is generated exactly with the same order as generated indices."""
buffer_batches = {}
while True:
if self.shutdown_flag.value == 1:
break

target_batch_idx = self.target_batch_idx.value

if target_batch_idx >= len(self):
break

if target_batch_idx in buffer_batches:
while True:
try:
batch_queue.put(
buffer_batches[target_batch_idx], timeout=1,
)
break
except queue.Full:
if self.shutdown_flag.value == 1:
break
logger.debug("batch queue is full!")
with self.target_batch_idx.get_lock():
self.target_batch_idx.value += 1
del buffer_batches[target_batch_idx]
continue

target_worker_id = target_batch_idx % self.num_workers
while True:
try:
(batch_idx,), batch_data = batch_part_queues[target_worker_id].get(
timeout=MP_QUEUE_GET_TIMEOUT
)
break
except queue.Empty:
if self.shutdown_flag.value == 1:
break
logger.debug(
"worker:{} data queue get timeout! target batch idx:{}".format(
target_worker_id, target_batch_idx
)
)

if batch_idx < target_batch_idx:
raise RuntimeError("batch_idx smaller than target_batch_idx")
elif batch_idx > target_batch_idx:
if self.strict:
raise RuntimeError("batch_idx larger than target_batch_idx")
logger.warning(
"missing target batch idx:{}, batch idx:{}".format(
target_batch_idx, batch_idx
)
)
buffer_batches[batch_idx] = batch_data
else:
try:
batch_queue.put(batch_data, timeout=1)
except queue.Full:
buffer_batches[batch_idx] = batch_data
continue

with self.target_batch_idx.get_lock():
self.target_batch_idx.value += 1

batch_queue.disconnect_client()

def _check_workers(self): def _check_workers(self):
"""Check the status of each worker and restart if necessary."""
# Check the status of each worker.
if not self.data_collecting_worker.is_alive(): if not self.data_collecting_worker.is_alive():
exitcode = self.task_feeding_worker.exitcode exitcode = self.task_feeding_worker.exitcode
if exitcode != 0: if exitcode != 0:
raise RuntimeError("data collecting worker died. {}".format(exitcode)) raise RuntimeError("data collecting worker died. {}".format(exitcode))
if self.strict:
if not self.task_feeding_worker.is_alive():
exitcode = self.task_feeding_worker.exitcode
if exitcode != 0:
raise RuntimeError("task feeding worker died. {}".format(exitcode))
for worker_id, worker in enumerate(self.workers):
if not worker.is_alive():
exitcode = worker.exitcode
if exitcode != 0:
raise RuntimeError(
"worker:{} died. {}".format(worker_id, exitcode)
)
else:
if not self.task_feeding_worker.is_alive():
exitcode = self.task_feeding_worker.exitcode
if exitcode != 0:
logger.error(
"task feeding worker died {}. Restarting".format(exitcode)
)
self.task_feeding_worker.join()
self.task_feeding_worker = multiprocessing.Process(
target=self._task_feeding_loop,
args=(iter(self.sampler), self.divide),
daemon=True,
)
self.task_feeding_worker.start()


failed_num = 0
for worker_id in range(self.num_workers):
if self.workers[worker_id].is_alive():
continue
if not self.task_feeding_worker.is_alive():
exitcode = self.task_feeding_worker.exitcode
if exitcode != 0:
raise RuntimeError("task feeding worker died. {}".format(exitcode))

for worker_id, worker in enumerate(self.workers):
if not worker.is_alive():
exitcode = worker.exitcode exitcode = worker.exitcode
if exitcode == 0:
continue
logger.error("worker {} died. Restarting".format(worker_id))
failed_num += 1
self.workers[worker_id].join()
worker = multiprocessing.Process(
target=self._worker_loop,
args=(
self.task_queues[worker_id],
self.batch_part_queues[worker_id],
self.transform,
self.collator,
self.seed + worker_id + 1,
),
daemon=True,
)
worker.start()
self.workers[worker_id] = worker
if exitcode != 0:
raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode))


if failed_num > 0:
logger.error("{} worker had exited".format(failed_num))
else:
logger.debug("all workers are alive.")
logger.debug("all workers are alive.")


def _try_get_next_batch(self): def _try_get_next_batch(self):
start_time = time.time() start_time = time.time()
@@ -530,7 +304,7 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
worker.terminate() worker.terminate()
worker.join() worker.join()


for q in self.batch_part_queues:
for q in self.trans_data_queues:
q.cancel_join_thread() q.cancel_join_thread()
q.close() q.close()


@@ -544,3 +318,183 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
def __del__(self): def __del__(self):
if self.__initialized: if self.__initialized:
self._shutdown() self._shutdown()


def _task_feeding_loop(
indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
):
# Feed the indices into the task queues
while True:
if shutdown_flag.value == 1:
break
batch_idx = feed_batch_idx.value
try:
indices = next(indices_iter)
except StopIteration:
break
if divide:
# make sure all task_queues is ready for put
while any([q.full() for q in task_queues]):
if shutdown_flag.value == 1:
return
# divide into small pieces, feed to different workers.
sub_num = math.ceil(len(indices) / num_workers)
for worker_id in range(num_workers):
sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num]
task_queues[worker_id].put((batch_idx, sub_indices))
else:
# distribute tasks to different workers uniformly.
target_id = batch_idx % num_workers
while task_queues[target_id].full():
if shutdown_flag.value == 1:
return
task_queues[target_id].put((batch_idx, indices))
with feed_batch_idx.get_lock():
feed_batch_idx.value += 1


def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag):
# Get dataset items and do the transform
random.seed(seed)
np.random.seed(seed)
while True:
if shutdown_flag.value == 1:
break
try:
batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT)
except queue.Empty:
continue
if len(indices) > 0:
items = [dataset[idx] for idx in indices]
trans_items = transform.apply_batch(items)
else:
# in case of incomplete last batch
trans_items = ()
while True:
try:
trans_data_queue.put((batch_idx, trans_items), timeout=1)
break
except queue.Full:
if shutdown_flag.value == 1:
break
logger.debug("batch part queue is full!")


def _data_gathering_loop(
trans_data_queues,
batch_queue,
collator,
length,
num_workers,
shutdown_flag,
target_idx,
):
# Gathering the small pieces of batch data into full batch data
while True:
if shutdown_flag.value == 1:
break

target_batch_idx = target_idx.value

if target_batch_idx >= length:
break

full_trans_items = []
for worker_id in range(num_workers):
while True:
try:
batch_idx, trans_items = trans_data_queues[worker_id].get(
timeout=MP_QUEUE_GET_TIMEOUT
)
break
except queue.Empty:
if shutdown_flag.value == 1:
break
logger.debug(
"worker:{} data queue get timeout! target batch idx:{}".format(
worker_id, target_batch_idx
)
)
if batch_idx != target_batch_idx:
raise RuntimeError(
"Unexperted batch_idx in data gathering loop. worker_id:{}.".format(
worker_id
)
)
else:
full_trans_items.extend(trans_items)

# Merge different parts into a batch.
full_batch = collator.apply(full_trans_items)

while True:
try:
batch_queue.put(full_batch, timeout=1)
break
except queue.Full:
if shutdown_flag.value == 1:
break
logger.debug("batch queue is full!")

with target_idx.get_lock():
target_idx.value += 1

batch_queue.disconnect_client()


def _data_selecting_loop(
trans_data_queues,
batch_queue,
collator,
length,
num_workers,
shutdown_flag,
target_idx,
):
# Make sure that batch is generated exactly with the same order as generated indices
while True:
if shutdown_flag.value == 1:
break

target_batch_idx = target_idx.value

if target_batch_idx >= length:
break

target_worker_id = target_batch_idx % num_workers
while True:
try:
batch_idx, trans_items = trans_data_queues[target_worker_id].get(
timeout=MP_QUEUE_GET_TIMEOUT
)
batch_data = collator.apply(trans_items)
break
except queue.Empty:
if shutdown_flag.value == 1:
break
logger.debug(
"worker:{} data queue get timeout! target batch idx:{}".format(
target_worker_id, target_batch_idx
)
)

if batch_idx != target_batch_idx:
raise RuntimeError(
"batch_idx {} mismatch the target_batch_idx {}".format(
batch_idx, target_batch_idx
)
)

while True:
try:
batch_queue.put(batch_data, timeout=1)
break
except queue.Full:
if shutdown_flag.value == 1:
break
logger.debug("batch queue is full!")

with target_idx.get_lock():
target_idx.value += 1

batch_queue.disconnect_client()

Loading…
Cancel
Save