|
|
@@ -20,7 +20,7 @@ from ..logger import get_logger |
|
|
|
from ..random.rng import _random_seed_generator |
|
|
|
from .collator import Collator |
|
|
|
from .dataset import Dataset, MapDataset, StreamDataset |
|
|
|
from .sampler import Sampler, SequentialSampler, StreamSampler |
|
|
|
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler |
|
|
|
from .transform import PseudoTransform, Transform |
|
|
|
|
|
|
|
logger = get_logger(__name__) |
|
|
@@ -88,17 +88,24 @@ class DataLoader: |
|
|
|
|
|
|
|
self.divide = divide |
|
|
|
|
|
|
|
if sampler is None: |
|
|
|
if isinstance(dataset, MapDataset): |
|
|
|
self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False) |
|
|
|
elif isinstance(dataset, StreamDataset): |
|
|
|
self.sampler = StreamSampler(batch_size=1) |
|
|
|
else: |
|
|
|
raise TypeError( |
|
|
|
"can not recognize this kind of dataset: %s" % type(dataset) |
|
|
|
) |
|
|
|
if isinstance(dataset, MapDataset): |
|
|
|
self.sampler = ( |
|
|
|
sampler |
|
|
|
if sampler |
|
|
|
else SequentialSampler(dataset, batch_size=1, drop_last=False) |
|
|
|
) |
|
|
|
assert isinstance( |
|
|
|
self.sampler, MapSampler |
|
|
|
), "types of dataset and sampler do not match" |
|
|
|
elif isinstance(dataset, StreamDataset): |
|
|
|
self.sampler = sampler if sampler else StreamSampler(batch_size=1) |
|
|
|
assert isinstance( |
|
|
|
self.sampler, StreamSampler |
|
|
|
), "types of dataset and sampler do not match" |
|
|
|
else: |
|
|
|
self.sampler = sampler |
|
|
|
raise TypeError( |
|
|
|
"can not recognize this kind of dataset: %s" % type(dataset) |
|
|
|
) |
|
|
|
|
|
|
|
if divide: |
|
|
|
if self.sampler.batch_size <= self.num_workers: |
|
|
@@ -352,7 +359,6 @@ class _BaseStreamDataLoaderIter: |
|
|
|
self.collator = loader.collator |
|
|
|
self.num_workers = loader.num_workers |
|
|
|
self.timeout = loader.timeout |
|
|
|
self.post_process = self.dataset.post_process |
|
|
|
|
|
|
|
def _get_next_batch(self): |
|
|
|
raise NotImplementedError |
|
|
@@ -361,13 +367,15 @@ class _BaseStreamDataLoaderIter: |
|
|
|
return self |
|
|
|
|
|
|
|
def __next__(self): |
|
|
|
return self.post_process(self._get_next_batch()) |
|
|
|
return self._get_next_batch() |
|
|
|
|
|
|
|
|
|
|
|
class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter): |
|
|
|
def __init__(self, loader): |
|
|
|
super().__init__(loader) |
|
|
|
self.dataset_iter = iter(self.dataset) |
|
|
|
self.idx = 0 |
|
|
|
self.data = None |
|
|
|
|
|
|
|
def _get_next_batch(self): |
|
|
|
ret = [] |
|
|
@@ -376,11 +384,30 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter): |
|
|
|
waited_time = time.time() - start_time |
|
|
|
if self.timeout > 0 and waited_time > self.timeout: |
|
|
|
raise RuntimeError("get_next_batch timeout!") |
|
|
|
item = next(self.dataset_iter) |
|
|
|
for idx in range(len(item[0])): |
|
|
|
trans_item = self.transform.apply(tuple(e[idx] for e in item)) |
|
|
|
ret.append(trans_item) |
|
|
|
if self.idx != 0: |
|
|
|
data = self.data |
|
|
|
else: |
|
|
|
try: |
|
|
|
raw_data = next(self.dataset_iter) |
|
|
|
except: |
|
|
|
continue |
|
|
|
assert len(raw_data) == 2 and isinstance( |
|
|
|
raw_data[0], bool |
|
|
|
), "raw_data must be a tuple" |
|
|
|
if not raw_data[0]: |
|
|
|
data = list((x,) for x in raw_data[1]) |
|
|
|
else: |
|
|
|
data = raw_data[1] |
|
|
|
for idx in range(self.idx, len(data[0])): |
|
|
|
trans_data = self.transform.apply(tuple(e[idx] for e in data)) |
|
|
|
ret.append(trans_data) |
|
|
|
if len(ret) == self.sampler.batch_size: |
|
|
|
if idx + 1 == len(data[0]): |
|
|
|
self.idx = 0 |
|
|
|
self.data = None |
|
|
|
else: |
|
|
|
self.idx = idx |
|
|
|
self.data = data |
|
|
|
break |
|
|
|
return self.collator.apply(ret) |
|
|
|
|
|
|
@@ -393,45 +420,80 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): |
|
|
|
|
|
|
|
self.shutdown_flag = multiprocessing.Value("i", 0) |
|
|
|
|
|
|
|
self.raw_data_queues = [ |
|
|
|
multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers) |
|
|
|
] |
|
|
|
|
|
|
|
self.trans_data_queues = [ |
|
|
|
multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers) |
|
|
|
] |
|
|
|
|
|
|
|
# shared-memory queue implemented by pyarrow plasma store |
|
|
|
from ._queue import PlasmaShmQueue |
|
|
|
|
|
|
|
self.batch_queue = PlasmaShmQueue(maxsize=2) |
|
|
|
self.workers = [] |
|
|
|
self.worker_queues = [ |
|
|
|
multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers) |
|
|
|
] |
|
|
|
|
|
|
|
self.recieve_worker = multiprocessing.Process(target=self._recieve, daemon=True) |
|
|
|
self.recieve_worker.start() |
|
|
|
|
|
|
|
self.transform_workers = [] |
|
|
|
for worker_id in range(self.num_workers): |
|
|
|
worker = multiprocessing.Process( |
|
|
|
target=self._gen_data, args=(worker_id,), daemon=True |
|
|
|
target=self._transform, args=(worker_id,), daemon=True |
|
|
|
) |
|
|
|
worker.start() |
|
|
|
self.workers.append(worker) |
|
|
|
self.collator_worker = multiprocessing.Process( |
|
|
|
target=self._gen_batch, daemon=True |
|
|
|
) |
|
|
|
self.collator_worker.start() |
|
|
|
self.transform_workers.append(worker) |
|
|
|
|
|
|
|
self.collect_worker = multiprocessing.Process(target=self._collect, daemon=True) |
|
|
|
self.collect_worker.start() |
|
|
|
|
|
|
|
self.__initialized = True |
|
|
|
|
|
|
|
def _gen_data(self, worker_id): |
|
|
|
def _recieve(self): |
|
|
|
dataset_iter = iter(self.dataset) |
|
|
|
cnt = -1 |
|
|
|
while True: |
|
|
|
if self.shutdown_flag.value == 1: |
|
|
|
break |
|
|
|
item = next(dataset_iter) |
|
|
|
for idx in range(len(item[0])): |
|
|
|
trans_item = self.transform.apply(tuple(e[idx] for e in item)) |
|
|
|
raw_data = next(dataset_iter) |
|
|
|
assert len(raw_data) == 2 and isinstance( |
|
|
|
raw_data[0], bool |
|
|
|
), "raw_data must be a tuple" |
|
|
|
if not raw_data[0]: |
|
|
|
data = list((x,) for x in raw_data[1]) |
|
|
|
else: |
|
|
|
data = raw_data[1] |
|
|
|
for idx in range(len(data[0])): |
|
|
|
while True: |
|
|
|
cnt += 1 |
|
|
|
qid = cnt % self.num_workers |
|
|
|
try: |
|
|
|
self.worker_queues[worker_id].put(trans_item) |
|
|
|
self.raw_data_queues[qid].put(tuple(e[idx] for e in data)) |
|
|
|
break |
|
|
|
except queue.Full: |
|
|
|
if self.shutdown_flag.value == 1: |
|
|
|
break |
|
|
|
logger.debug("batch part queue is full") |
|
|
|
logger.debug("raw data queue is full") |
|
|
|
|
|
|
|
def _gen_batch(self): |
|
|
|
def _transform(self, worker_id): |
|
|
|
while True: |
|
|
|
if self.shutdown_flag.value == 1: |
|
|
|
break |
|
|
|
try: |
|
|
|
data = self.raw_data_queues[worker_id].get(timeout=MP_QUEUE_GET_TIMEOUT) |
|
|
|
except queue.Empty: |
|
|
|
continue |
|
|
|
trans_data = self.transform.apply(data) |
|
|
|
while True: |
|
|
|
try: |
|
|
|
self.trans_data_queues[worker_id].put(trans_data) |
|
|
|
break |
|
|
|
except queue.Full: |
|
|
|
if self.shutdown_flag.value == 1: |
|
|
|
break |
|
|
|
logger.debug("batch queue if full") |
|
|
|
|
|
|
|
def _collect(self): |
|
|
|
cnt = -1 |
|
|
|
trans_items = [] |
|
|
|
while True: |
|
|
@@ -440,7 +502,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): |
|
|
|
cnt += 1 |
|
|
|
queue_id = cnt % self.num_workers |
|
|
|
try: |
|
|
|
trans_item = self.worker_queues[queue_id].get( |
|
|
|
trans_item = self.trans_data_queues[queue_id].get( |
|
|
|
timeout=MP_QUEUE_GET_TIMEOUT |
|
|
|
) |
|
|
|
except queue.Empty: |
|
|
@@ -459,12 +521,12 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): |
|
|
|
trans_items = [] |
|
|
|
|
|
|
|
def _check_workers(self): |
|
|
|
if not self.collator_worker.is_alive(): |
|
|
|
exitcode = self.collator_worker.exitcode |
|
|
|
if not self.collect_worker.is_alive(): |
|
|
|
exitcode = self.collect_worker.exitcode |
|
|
|
if exitcode != 0: |
|
|
|
raise RuntimeError("collator worker died. {}".format(exitcode)) |
|
|
|
|
|
|
|
for worker_id, worker in enumerate(self.workers): |
|
|
|
for worker_id, worker in enumerate(self.transform_workers): |
|
|
|
if not worker.is_alive(): |
|
|
|
exitcode = worker.exitcode |
|
|
|
if exitcode != 0: |
|
|
@@ -492,16 +554,24 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): |
|
|
|
with self.shutdown_flag.get_lock(): |
|
|
|
self.shutdown_flag.value = 1 |
|
|
|
|
|
|
|
if self.collator_worker.is_alive(): |
|
|
|
self.collator_worker.terminate() |
|
|
|
self.collator_worker.join() |
|
|
|
if self.recieve_worker.is_alive(): |
|
|
|
self.recieve_worker.terminate() |
|
|
|
self.recieve_worker.join() |
|
|
|
|
|
|
|
for worker in self.workers: |
|
|
|
if self.collect_worker.is_alive(): |
|
|
|
self.collect_worker.terminate() |
|
|
|
self.collect_worker.join() |
|
|
|
|
|
|
|
for worker in self.transform_workers: |
|
|
|
if worker.is_alive(): |
|
|
|
worker.terminate() |
|
|
|
worker.join() |
|
|
|
|
|
|
|
for q in self.worker_queues: |
|
|
|
for q in self.raw_data_queues: |
|
|
|
q.cancel_join_thread() |
|
|
|
q.close() |
|
|
|
|
|
|
|
for q in self.trans_data_queues: |
|
|
|
q.cancel_join_thread() |
|
|
|
q.close() |
|
|
|
|
|
|
|