@@ -15,9 +15,8 @@ import time
 
 import numpy as np
 
-import megengine as mge
-
 from ..logger import get_logger
+from ..random.rng import _random_seed_generator
 from .collator import Collator
 from .dataset import Dataset
 from .sampler import Sampler, SequentialSampler
@@ -87,8 +86,6 @@ class DataLoader:
         self.divide = divide
 
-        self.rng = np.random.RandomState()
-
         if sampler is None:
             self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False)
         else:
@@ -130,7 +127,7 @@ class _BaseDataLoaderIter:
     def __init__(self, loader):
         self.dataset = loader.dataset
         self.sampler = loader.sampler
-        self.seed = loader.rng.randint(1e9)
+        self.seed = _random_seed_generator().__next__()
         self.transform = loader.transform
         self.collator = loader.collator
         self.num_workers = loader.num_workers
@@ -173,10 +170,6 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
     def __init__(self, loader):
         super(_ParallelDataLoaderIter, self).__init__(loader)
 
-        # if any worker died, all workers will be shutdown.
-        self.strict = True
-        # TODO: put `strict` into DataLoader args or not?
-
         self.task_queues = [
             multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
         ]
@@ -185,7 +178,7 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
         self.target_batch_idx = multiprocessing.Value("i", 0)
         self.shutdown_flag = multiprocessing.Value("i", 0)
 
-        self.batch_part_queues = [
+        self.trans_data_queues = [
             multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
         ]
 
@@ -195,8 +188,15 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
         self.batch_queue = PlasmaShmQueue(maxsize=2)
 
         self.task_feeding_worker = multiprocessing.Process(
-            target=self._task_feeding_loop,
-            args=(iter(self.sampler), self.divide),
+            target=_task_feeding_loop,
+            args=(
+                iter(self.sampler),
+                self.task_queues,
+                self.num_workers,
+                self.divide,
+                self.shutdown_flag,
+                self.feed_batch_idx,
+            ),
             daemon=True,
         )
         self.task_feeding_worker.start()
@@ -204,13 +204,14 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
         self.workers = []
         for worker_id in range(self.num_workers):
             worker = multiprocessing.Process(
-                target=self._worker_loop,
+                target=_worker_loop,
                 args=(
+                    self.dataset,
                     self.task_queues[worker_id],
-                    self.batch_part_queues[worker_id],
+                    self.trans_data_queues[worker_id],
                     self.transform,
-                    self.collator,
                     self.seed + worker_id + 1,
+                    self.shutdown_flag,
                 ),
                 daemon=True,
             )
@@ -219,282 +220,55 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
 
         if self.divide:
             self.data_collecting_worker = multiprocessing.Process(
-                target=self._data_gathering_loop,
-                args=(self.batch_part_queues, self.batch_queue,),
+                target=_data_gathering_loop,
+                args=(
+                    self.trans_data_queues,
+                    self.batch_queue,
+                    self.collator,
+                    len(self),
+                    self.num_workers,
+                    self.shutdown_flag,
+                    self.target_batch_idx,
+                ),
                 daemon=True,
             )
         else:
             self.data_collecting_worker = multiprocessing.Process(
-                target=self._data_selecting_loop,
-                args=(self.batch_part_queues, self.batch_queue,),
+                target=_data_selecting_loop,
+                args=(
+                    self.trans_data_queues,
+                    self.batch_queue,
+                    self.collator,
+                    len(self),
+                    self.num_workers,
+                    self.shutdown_flag,
+                    self.target_batch_idx,
+                ),
                 daemon=True,
             )
         self.data_collecting_worker.start()
 
         self.__initialized = True
 
-    def _task_feeding_loop(self, indices_iter, divide):
-        while True:
-            if self.shutdown_flag.value == 1:
-                break
-            batch_idx = self.feed_batch_idx.value
-            try:
-                indices = next(indices_iter)
-            except StopIteration:
-                break
-            if divide:
-                # make sure all task_queues is ready for put
-                while any([q.full() for q in self.task_queues]):
-                    if self.shutdown_flag.value == 1:
-                        return
-                # divide into small pieces, feed to different workers.
-                sub_num = math.ceil(len(indices) / self.num_workers)
-                for worker_id in range(self.num_workers):
-                    sub_indices = indices[
-                        worker_id * sub_num : (worker_id + 1) * sub_num
-                    ]
-                    self.task_queues[worker_id].put((batch_idx, sub_indices))
-            else:
-                # distribute tasks to different workers uniformly.
-                target_id = batch_idx % self.num_workers
-                while self.task_queues[target_id].full():
-                    if self.shutdown_flag.value == 1:
-                        return
-                self.task_queues[target_id].put((batch_idx, indices))
-            with self.feed_batch_idx.get_lock():
-                self.feed_batch_idx.value += 1
-
-    def _worker_loop(self, task_queue, data_queue, transform, collator, seed):
-        random.seed(seed)
-        np.random.seed(seed)
-        while True:
-            if self.shutdown_flag.value == 1:
-                break
-            try:
-                batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT)
-            except queue.Empty:
-                continue
-            if len(indices) > 0:
-                items = [self.dataset[idx] for idx in indices]
-                trans_items = transform.apply_batch(items)
-                batch_data = collator.apply(trans_items)
-            else:
-                # in case of incomplete last batch
-                batch_data = ()
-            while True:
-                try:
-                    data_queue.put((np.array([batch_idx]), batch_data), timeout=1)
-                    break
-                except queue.Full:
-                    if self.shutdown_flag.value == 1:
-                        break
-                    logger.debug("batch part queue is full!")
-                    continue
-
-    def _data_gathering_loop(self, batch_part_queues, batch_queue):
-        r"""Gathering the small pieces of batch data into full batch data."""
-        gathered_data = collections.defaultdict(dict)
-        while True:
-            if self.shutdown_flag.value == 1:
-                break
-
-            target_batch_idx = self.target_batch_idx.value
-
-            if target_batch_idx >= len(self):
-                break
-
-            for worker_id in range(self.num_workers):
-                if worker_id in gathered_data[target_batch_idx]:
-                    continue
-                while True:
-                    try:
-                        (batch_idx,), batch_part = batch_part_queues[worker_id].get(
-                            timeout=MP_QUEUE_GET_TIMEOUT
-                        )
-                        break
-                    except queue.Empty:
-                        if self.shutdown_flag.value == 1:
-                            break
-                        logger.debug(
-                            "worker:{} data queue get timeout! target batch idx:{}".format(
-                                worker_id, target_batch_idx
-                            )
-                        )
-                if batch_idx < target_batch_idx:
-                    raise RuntimeError(
-                        "Unexperted batch_idx in data gathering loop. worker_id:{}.".format(
-                            worker_id
-                        )
-                    )
-                else:
-                    gathered_data[batch_idx][worker_id] = batch_part
-
-            if len(gathered_data[target_batch_idx]) < self.num_workers:
-                length = len(gathered_data[target_batch_idx])
-                if self.strict:
-                    raise RuntimeError("Parts missing in data gathering loop.")
-                logger.warning(
-                    "target_batch_idx:{}, {} part(s) missing.".format(
-                        target_batch_idx, self.num_workers - length
-                    )
-                )
-                del gathered_data[target_batch_idx]
-                with self.target_batch_idx.get_lock():
-                    self.target_batch_idx.value += 1
-                continue
-
-            # Merge different parts.
-            full_batch = [[] for _ in range(len(gathered_data[target_batch_idx][0]))]
-            for idx in range(self.num_workers):
-                for i, field in enumerate(gathered_data[target_batch_idx][idx]):
-                    full_batch[i].append(field)
-            full_batch = tuple([np.concatenate(field, axis=0) for field in full_batch])
-
-            while True:
-                try:
-                    batch_queue.put(full_batch, timeout=1)
-                    break
-                except queue.Full:
-                    if self.shutdown_flag.value == 1:
-                        break
-                    logger.debug("batch queue is full!")
-                    continue
-
-            del gathered_data[target_batch_idx]
-
-            with self.target_batch_idx.get_lock():
-                self.target_batch_idx.value += 1
-
-        batch_queue.disconnect_client()
-
-    def _data_selecting_loop(self, batch_part_queues, batch_queue):
-        r"""Make sure that batch is generated exactly with the same order as generated indices."""
-        buffer_batches = {}
-        while True:
-            if self.shutdown_flag.value == 1:
-                break
-
-            target_batch_idx = self.target_batch_idx.value
-
-            if target_batch_idx >= len(self):
-                break
-
-            if target_batch_idx in buffer_batches:
-                while True:
-                    try:
-                        batch_queue.put(
-                            buffer_batches[target_batch_idx], timeout=1,
-                        )
-                        break
-                    except queue.Full:
-                        if self.shutdown_flag.value == 1:
-                            break
-                        logger.debug("batch queue is full!")
-                with self.target_batch_idx.get_lock():
-                    self.target_batch_idx.value += 1
-                del buffer_batches[target_batch_idx]
-                continue
-
-            target_worker_id = target_batch_idx % self.num_workers
-            while True:
-                try:
-                    (batch_idx,), batch_data = batch_part_queues[target_worker_id].get(
-                        timeout=MP_QUEUE_GET_TIMEOUT
-                    )
-                    break
-                except queue.Empty:
-                    if self.shutdown_flag.value == 1:
-                        break
-                    logger.debug(
-                        "worker:{} data queue get timeout! target batch idx:{}".format(
-                            target_worker_id, target_batch_idx
-                        )
-                    )
-
-            if batch_idx < target_batch_idx:
-                raise RuntimeError("batch_idx smaller than target_batch_idx")
-            elif batch_idx > target_batch_idx:
-                if self.strict:
-                    raise RuntimeError("batch_idx larger than target_batch_idx")
-                logger.warning(
-                    "missing target batch idx:{}, batch idx:{}".format(
-                        target_batch_idx, batch_idx
-                    )
-                )
-                buffer_batches[batch_idx] = batch_data
-            else:
-                try:
-                    batch_queue.put(batch_data, timeout=1)
-                except queue.Full:
-                    buffer_batches[batch_idx] = batch_data
-                    continue
-
-                with self.target_batch_idx.get_lock():
-                    self.target_batch_idx.value += 1
-
-        batch_queue.disconnect_client()
-
     def _check_workers(self):
-        """Check the status of each worker and restart if necessary."""
+        # Check the status of each worker.
         if not self.data_collecting_worker.is_alive():
             exitcode = self.task_feeding_worker.exitcode
             if exitcode != 0:
                 raise RuntimeError("data collecting worker died. {}".format(exitcode))
 
-        if self.strict:
-            if not self.task_feeding_worker.is_alive():
-                exitcode = self.task_feeding_worker.exitcode
-                if exitcode != 0:
-                    raise RuntimeError("task feeding worker died. {}".format(exitcode))
-
-            for worker_id, worker in enumerate(self.workers):
-                if not worker.is_alive():
-                    exitcode = worker.exitcode
-                    if exitcode != 0:
-                        raise RuntimeError(
-                            "worker:{} died. {}".format(worker_id, exitcode)
-                        )
-        else:
-            if not self.task_feeding_worker.is_alive():
-                exitcode = self.task_feeding_worker.exitcode
-                if exitcode != 0:
-                    logger.error(
-                        "task feeding worker died {}. Restarting".format(exitcode)
-                    )
-                    self.task_feeding_worker.join()
-                    self.task_feeding_worker = multiprocessing.Process(
-                        target=self._task_feeding_loop,
-                        args=(iter(self.sampler), self.divide),
-                        daemon=True,
-                    )
-                    self.task_feeding_worker.start()
-
-            failed_num = 0
-            for worker_id in range(self.num_workers):
-                if self.workers[worker_id].is_alive():
-                    continue
+        if not self.task_feeding_worker.is_alive():
+            exitcode = self.task_feeding_worker.exitcode
+            if exitcode != 0:
+                raise RuntimeError("task feeding worker died. {}".format(exitcode))
+
+        for worker_id, worker in enumerate(self.workers):
+            if not worker.is_alive():
                 exitcode = worker.exitcode
-                if exitcode == 0:
-                    continue
-                logger.error("worker {} died. Restarting".format(worker_id))
-                failed_num += 1
-                self.workers[worker_id].join()
-                worker = multiprocessing.Process(
-                    target=self._worker_loop,
-                    args=(
-                        self.task_queues[worker_id],
-                        self.batch_part_queues[worker_id],
-                        self.transform,
-                        self.collator,
-                        self.seed + worker_id + 1,
-                    ),
-                    daemon=True,
-                )
-                worker.start()
-                self.workers[worker_id] = worker
-
-            if failed_num > 0:
-                logger.error("{} worker had exited".format(failed_num))
-            else:
-                logger.debug("all workers are alive.")
+                if exitcode != 0:
+                    raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode))
+
+        logger.debug("all workers are alive.")
 
     def _try_get_next_batch(self):
         start_time = time.time()
@@ -530,7 +304,7 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
             worker.terminate()
             worker.join()
 
-        for q in self.batch_part_queues:
+        for q in self.trans_data_queues:
             q.cancel_join_thread()
             q.close()
 
@@ -544,3 +318,183 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
     def __del__(self):
         if self.__initialized:
             self._shutdown()
+
+
+def _task_feeding_loop(
+    indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
+):
+    # Feed the indices into the task queues
+    while True:
+        if shutdown_flag.value == 1:
+            break
+        batch_idx = feed_batch_idx.value
+        try:
+            indices = next(indices_iter)
+        except StopIteration:
+            break
+        if divide:
+            # make sure all task_queues is ready for put
+            while any([q.full() for q in task_queues]):
+                if shutdown_flag.value == 1:
+                    return
+            # divide into small pieces, feed to different workers.
+            sub_num = math.ceil(len(indices) / num_workers)
+            for worker_id in range(num_workers):
+                sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num]
+                task_queues[worker_id].put((batch_idx, sub_indices))
+        else:
+            # distribute tasks to different workers uniformly.
+            target_id = batch_idx % num_workers
+            while task_queues[target_id].full():
+                if shutdown_flag.value == 1:
+                    return
+            task_queues[target_id].put((batch_idx, indices))
+        with feed_batch_idx.get_lock():
+            feed_batch_idx.value += 1
+
+
+def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag):
+    # Get dataset items and do the transform
+    random.seed(seed)
+    np.random.seed(seed)
+    while True:
+        if shutdown_flag.value == 1:
+            break
+        try:
+            batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT)
+        except queue.Empty:
+            continue
+        if len(indices) > 0:
+            items = [dataset[idx] for idx in indices]
+            trans_items = transform.apply_batch(items)
+        else:
+            # in case of incomplete last batch
+            trans_items = ()
+        while True:
+            try:
+                trans_data_queue.put((batch_idx, trans_items), timeout=1)
+                break
+            except queue.Full:
+                if shutdown_flag.value == 1:
+                    break
+                logger.debug("batch part queue is full!")
+
+
+def _data_gathering_loop(
+    trans_data_queues,
+    batch_queue,
+    collator,
+    length,
+    num_workers,
+    shutdown_flag,
+    target_idx,
+):
+    # Gathering the small pieces of batch data into full batch data
+    while True:
+        if shutdown_flag.value == 1:
+            break
+
+        target_batch_idx = target_idx.value
+
+        if target_batch_idx >= length:
+            break
+
+        full_trans_items = []
+        for worker_id in range(num_workers):
+            while True:
+                try:
+                    batch_idx, trans_items = trans_data_queues[worker_id].get(
+                        timeout=MP_QUEUE_GET_TIMEOUT
+                    )
+                    break
+                except queue.Empty:
+                    if shutdown_flag.value == 1:
+                        break
+                    logger.debug(
+                        "worker:{} data queue get timeout! target batch idx:{}".format(
+                            worker_id, target_batch_idx
+                        )
+                    )
+            if batch_idx != target_batch_idx:
+                raise RuntimeError(
+                    "Unexperted batch_idx in data gathering loop. worker_id:{}.".format(
+                        worker_id
+                    )
+                )
+            else:
+                full_trans_items.extend(trans_items)
+
+        # Merge different parts into a batch.
+        full_batch = collator.apply(full_trans_items)
+
+        while True:
+            try:
+                batch_queue.put(full_batch, timeout=1)
+                break
+            except queue.Full:
+                if shutdown_flag.value == 1:
+                    break
+                logger.debug("batch queue is full!")
+
+        with target_idx.get_lock():
+            target_idx.value += 1
+
+    batch_queue.disconnect_client()
+
+
+def _data_selecting_loop(
+    trans_data_queues,
+    batch_queue,
+    collator,
+    length,
+    num_workers,
+    shutdown_flag,
+    target_idx,
+):
+    # Make sure that batch is generated exactly with the same order as generated indices
+    while True:
+        if shutdown_flag.value == 1:
+            break
+
+        target_batch_idx = target_idx.value
+
+        if target_batch_idx >= length:
+            break
+
+        target_worker_id = target_batch_idx % num_workers
+        while True:
+            try:
+                batch_idx, trans_items = trans_data_queues[target_worker_id].get(
+                    timeout=MP_QUEUE_GET_TIMEOUT
+                )
+                batch_data = collator.apply(trans_items)
+                break
+            except queue.Empty:
+                if shutdown_flag.value == 1:
+                    break
+                logger.debug(
+                    "worker:{} data queue get timeout! target batch idx:{}".format(
+                        target_worker_id, target_batch_idx
+                    )
+                )
+
+        if batch_idx != target_batch_idx:
+            raise RuntimeError(
+                "batch_idx {} mismatch the target_batch_idx {}".format(
+                    batch_idx, target_batch_idx
+                )
+            )
+
+        while True:
+            try:
+                batch_queue.put(batch_data, timeout=1)
+                break
+            except queue.Full:
+                if shutdown_flag.value == 1:
+                    break
+                logger.debug("batch queue is full!")
+
+        with target_idx.get_lock():
+            target_idx.value += 1
+
+    batch_queue.disconnect_client()