@@ -14,6 +14,7 @@ import queue
 import random
 import threading
 import time
+from typing import Callable
 
 import numpy as np
 
@@ -36,6 +37,10 @@ logger = get_logger(__name__)
 GLOBAL_TIMEOUT = 5
 
 
+def raise_timeout_error():
+    raise RuntimeError("dataloader timeout")
+
+
 class DataLoader:
     __initialized = False
 
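
raise_timeout_error is only the default for the new timeout_event hook. The stream iterators below use the callback's return value as substitute raw data (raw_data = self.timeout_event()), so a custom handler may recover instead of raising. A minimal sketch, assuming two-field (image, label) samples; the name and placeholder payload are illustrative, not part of this patch:

    def fallback_timeout_event():
        # Hypothetical handler: unlike raise_timeout_error, return a
        # (batched, data) tuple in the StreamDataset convention so the
        # iterator can continue with placeholder data.
        return (False, ("<pad>", 0))
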
@@ -46,7 +51,8 @@ class DataLoader:
         transform: Transform = None,
         collator: Collator = None,
         num_workers: int = 0,
-        timeout: int = GLOBAL_TIMEOUT,
+        timeout: int = 0,
+        timeout_event: Callable = raise_timeout_error,
         divide: bool = False,
     ):
         r"""
@@ -71,6 +77,9 @@ class DataLoader:
         :type timeout: int
         :param timeout: if positive, means the timeout value(second) for collecting a
             batch from workers. Default: 0
+        :type timeout_event: Callable
+        :param timeout_event: callback function triggered on timeout; by default
+            it raises a runtime error.
         :type divide: bool
         :param divide: define the paralleling strategy in multi-processing mode.
             ``True`` means one batch is divided into :attr:`num_workers` pieces, and
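
At the call site, the new keyword arguments read like this (an illustrative sketch: my_dataset and my_sampler stand in for a real dataset/sampler pair and are not defined in this patch):

    loader = DataLoader(
        my_dataset,
        sampler=my_sampler,
        num_workers=4,
        timeout=10,                         # wait at most 10s per batch
        timeout_event=raise_timeout_error,  # or any zero-argument callable
    )
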
@@ -92,6 +101,7 @@ class DataLoader:
 
         self.num_workers = num_workers
         self.timeout = timeout
+        self.timeout_event = timeout_event
 
         self.divide = divide
 
@@ -168,6 +178,7 @@ class _BaseMapDataLoaderIter:
         self.collator = loader.collator
         self.num_workers = loader.num_workers
         self.timeout = loader.timeout
+        self.timeout_event = loader.timeout_event
         self.divide = loader.divide
         self.num_processed = 0
 
@@ -306,7 +317,7 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
 
         logger.debug("all workers are alive.")
 
-    def _try_get_next_batch(self):
+    def _get_next_batch(self):
         start_time = time.time()
         while True:
             self._check_workers()
@@ -319,10 +330,6 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
             if waited_time > self.timeout:
                 raise RuntimeError("get_next_batch timeout!")
 
-    def _get_next_batch(self):
-        batch_data = self._try_get_next_batch()
-        return batch_data
-
     def _shutdown(self):
         with self.shutdown_flag.get_lock():
             self.shutdown_flag.value = 1
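
With the trivial _get_next_batch wrapper gone, the polling loop itself is the fetch path. Its shape is poll-with-deadline: retry a short blocking get so worker liveness can be rechecked between attempts, and give up only once the accumulated wait exceeds the timeout. The pattern in isolation, standard library only (the 1-second poll interval mirrors the loader's; the queue and deadline are illustrative):

    import queue
    import time

    def get_with_deadline(q, deadline):
        start_time = time.time()
        while True:
            try:
                return q.get(timeout=1)  # short poll; liveness checks fit between polls
            except queue.Empty:
                if time.time() - start_time > deadline:
                    raise RuntimeError("get_next_batch timeout!")
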
@@ -364,10 +371,24 @@ class _BaseStreamDataLoaderIter:
         self.collator = loader.collator
         self.num_workers = loader.num_workers
         self.timeout = loader.timeout
+        self.timeout_event = loader.timeout_event
 
     def _get_next_batch(self):
         raise NotImplementedError
 
+    def _process_raw_data(self, raw_data):
+        assert len(raw_data) == 2 and isinstance(
+            raw_data[0], bool
+        ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
+        if not raw_data[0]:
+            data = list((x,) for x in raw_data[1])
+        else:
+            data = raw_data[1]
+        ret = []
+        for idx in range(len(data[0])):
+            ret.append(tuple(e[idx] for e in data))
+        return ret
+
     def __iter__(self):
         return self
 
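
The new _process_raw_data hoists into the base class the normalization both iterators need: a StreamDataset yields (batched, data), where an unbatched sample is a tuple of fields and a batched one is a list of parallel field columns, and either form becomes a flat list of per-sample tuples. The same reshaping in plain Python with illustrative inputs:

    def process_raw_data(raw_data):
        assert len(raw_data) == 2 and isinstance(raw_data[0], bool)
        if not raw_data[0]:
            data = list((x,) for x in raw_data[1])  # one sample: wrap each field
        else:
            data = raw_data[1]                      # batch: parallel field columns
        return [tuple(e[idx] for e in data) for idx in range(len(data[0]))]

    print(process_raw_data((False, ("img0", 0))))
    # -> [('img0', 0)]
    print(process_raw_data((True, [["img0", "img1"], [0, 1]])))
    # -> [('img0', 0), ('img1', 1)]
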
@@ -380,42 +401,43 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
         super().__init__(loader)
         self.dataset_iter = iter(self.dataset)
-        self.idx = 0
-        self.data = None
+        self.unused = []
 
-    def _get_next_batch(self):
-        ret = []
-        while len(ret) != self.sampler.batch_size:
-            if self.idx != 0:
-                data = self.data
-            else:
-                try:
+    def _try_get_raw_data(self, start_time):
+        raw_data = None
+        while not raw_data:
+            try:
+                if self.timeout > 0:
+                    timer = threading.Timer(self.timeout, thread.interrupt_main)
+                    timer.start()
-                    raw_data = next(self.dataset_iter)
+                raw_data = next(self.dataset_iter)
+                if self.timeout > 0:
+                    timer.cancel()
-                except KeyboardInterrupt:
-                    raise RuntimeError("get_next_batch timeout!")
-                except:
+            except KeyboardInterrupt:
+                raw_data = self.timeout_event()
+            except:
+                if self.timeout > 0:
+                    timer.cancel()
-                    continue
-            assert len(raw_data) == 2 and isinstance(
-                raw_data[0], bool
-            ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
-            if not raw_data[0]:
-                data = list((x,) for x in raw_data[1])
-            else:
-                data = raw_data[1]
-            for idx in range(self.idx, len(data[0])):
-                trans_data = self.transform.apply(tuple(e[idx] for e in data))
-                ret.append(trans_data)
-                if len(ret) == self.sampler.batch_size:
-                    if idx + 1 == len(data[0]):
-                        self.idx = 0
-                        self.data = None
-                    else:
-                        self.idx = idx
-                        self.data = data
-                    break
+                    waited_time = time.time() - start_time
+                    if waited_time > self.timeout:
+                        raw_data = self.timeout_event()
+        return raw_data
 
+    def _get_next_batch(self):
+        ret = []
+        start_time = time.time()
+        while len(ret) < self.sampler.batch_size:
+            if len(self.unused) != 0:
+                batch_data = self.unused
+            else:
+                raw_data = self._try_get_raw_data(start_time)
+                batch_data = self._process_raw_data(raw_data)
+
+            while len(batch_data) != 0 and len(ret) < self.sampler.batch_size:
+                data = batch_data.pop()
+                ret.append(self.transform.apply(data))
+            self.unused = batch_data
+
         return self.collator.apply(ret)
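
The serial path bounds the blocking next(self.dataset_iter) with a threading.Timer that fires thread.interrupt_main, which surfaces as KeyboardInterrupt in the main thread. A self-contained sketch of just that mechanism (standard library only; slow_source is a stand-in for a stalled StreamDataset, and the interrupt is only delivered while Python bytecode is running, not inside a call blocked in C):

    import _thread as thread
    import threading
    import time

    def next_with_timeout(it, timeout):
        timer = threading.Timer(timeout, thread.interrupt_main)
        timer.start()
        try:
            return next(it)      # a pending KeyboardInterrupt lands here
        except KeyboardInterrupt:
            raise RuntimeError("dataloader timeout")
        finally:
            timer.cancel()       # disarm the timer on success and on failure

    def slow_source():
        while True:
            for _ in range(600):
                time.sleep(0.1)  # short sleeps keep the iterator interruptible
            yield None

    # next_with_timeout(slow_source(), 1) raises RuntimeError after about 1s
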
@@ -440,49 +462,52 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
 
         self.batch_queue = PlasmaShmQueue(maxsize=2)
 
-        self.recieve_worker = multiprocessing.Process(target=self._recieve, daemon=True)
+        self.recieve_worker = multiprocessing.Process(
+            target=self._worker_to_raw_data_queues, daemon=True
+        )
         self.recieve_worker.start()
 
         self.transform_workers = []
         for worker_id in range(self.num_workers):
             worker = multiprocessing.Process(
-                target=self._transform, args=(worker_id,), daemon=True
+                target=self._worker_to_trans_data_queues, args=(worker_id,), daemon=True
             )
             worker.start()
             self.transform_workers.append(worker)
 
-        self.collect_worker = multiprocessing.Process(target=self._collect, daemon=True)
+        self.collect_worker = multiprocessing.Process(
+            target=self._worker_to_batch_queue, daemon=True
+        )
         self.collect_worker.start()
 
         self.__initialized = True
 
-    def _recieve(self):
+    def _put_raw_data_queues(self, raw_data, qidx):
+        batch_data = self._process_raw_data(raw_data)
+        for data in batch_data:
+            while True:
+                qidx = qidx % self.num_workers
+                try:
+                    self.raw_data_queues[qidx].put(data)
+                    break
+                except queue.Full:
+                    if self.shutdown_flag.value == 1:
+                        break
+                    logger.debug("raw data queue %d is full" % qidx)
+                finally:
+                    qidx += 1
+        return qidx
+
+    def _worker_to_raw_data_queues(self):
         dataset_iter = iter(self.dataset)
-        cnt = -1
+        qidx = 0
         while True:
             if self.shutdown_flag.value == 1:
                 break
             raw_data = next(dataset_iter)
-            assert len(raw_data) == 2 and isinstance(
-                raw_data[0], bool
-            ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
-            if not raw_data[0]:
-                data = list((x,) for x in raw_data[1])
-            else:
-                data = raw_data[1]
-            for idx in range(len(data[0])):
-                while True:
-                    cnt += 1
-                    qid = cnt % self.num_workers
-                    try:
-                        self.raw_data_queues[qid].put(tuple(e[idx] for e in data))
-                        break
-                    except queue.Full:
-                        if self.shutdown_flag.value == 1:
-                            break
-                        logger.debug("raw data queue is full")
+            qidx = self._put_raw_data_queues(raw_data, qidx)
 
-    def _transform(self, worker_id):
+    def _worker_to_trans_data_queues(self, worker_id):
         while True:
             if self.shutdown_flag.value == 1:
                 break
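
_put_raw_data_queues factors the round-robin fan-out out of the receive loop so the timeout path in the last hunk can reuse it. Note that finally: qidx += 1 advances the cursor on success and on queue.Full alike, so a full queue is skipped rather than retried. A toy version over plain bounded queues (put_nowait stands in for the loader's put so queue.Full can actually fire; the real loop also bails out on the shutdown flag):

    import queue

    def fan_out(items, queues, qidx=0):
        for item in items:
            while True:               # spins while every queue is full
                qidx = qidx % len(queues)
                try:
                    queues[qidx].put_nowait(item)
                    break             # placed; move on to the next item
                except queue.Full:
                    pass              # skip this full queue, try the next
                finally:
                    qidx += 1         # advance the cursor on every attempt
        return qidx

    qs = [queue.Queue(maxsize=2) for _ in range(2)]
    fan_out(["a", "b", "c"], qs)      # a -> qs[0], b -> qs[1], c -> qs[0]
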
@@ -500,7 +525,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
                         break
                     logger.debug("batch queue is full")
 
-    def _collect(self):
+    def _worker_to_batch_queue(self):
         cnt = -1
         trans_items = []
         while True:
@@ -541,7 +566,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
                     "worker: {} died. {}".format(worker_id, exitcode)
                 )
 
-    def _try_get_next_batch(self):
+    def _get_next_batch(self):
         start_time = time.time()
         while True:
             self._check_workers()
@@ -551,11 +576,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
                 logger.debug("batch queue empty!")
             waited_time = time.time() - start_time
             if self.timeout > 0 and waited_time > self.timeout:
-                raise RuntimeError("get_next_batch timeout!")
-
-    def _get_next_batch(self):
-        batch_data = self._try_get_next_batch()
-        return batch_data
+                self._put_raw_data_queues(self.timeout_event(), 0)
 
     def _shutdown(self):
         with self.shutdown_flag.get_lock():