|
|
@@ -19,8 +19,8 @@ import numpy as np |
|
|
|
from ..logger import get_logger |
|
|
|
from ..random.rng import _random_seed_generator |
|
|
|
from .collator import Collator |
|
|
|
from .dataset import Dataset |
|
|
|
from .sampler import Sampler, SequentialSampler |
|
|
|
from .dataset import Dataset, MapDataset, StreamDataset |
|
|
|
from .sampler import Sampler, SequentialSampler, StreamSampler |
|
|
|
from .transform import PseudoTransform, Transform |
|
|
|
|
|
|
|
logger = get_logger(__name__) |
|
|
@@ -82,13 +82,21 @@ class DataLoader: |
|
|
|
raise ValueError("divide should not be set to True when num_workers <= 1") |
|
|
|
|
|
|
|
self.dataset = dataset |
|
|
|
|
|
|
|
self.num_workers = num_workers |
|
|
|
self.timeout = timeout |
|
|
|
|
|
|
|
self.divide = divide |
|
|
|
|
|
|
|
if sampler is None: |
|
|
|
self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False) |
|
|
|
if isinstance(dataset, MapDataset): |
|
|
|
self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False) |
|
|
|
elif isinstance(dataset, StreamDataset): |
|
|
|
self.sampler = StreamSampler(batch_size=1) |
|
|
|
else: |
|
|
|
raise TypeError( |
|
|
|
"can not recognize this kind of dataset: %s" % type(dataset) |
|
|
|
) |
|
|
|
else: |
|
|
|
self.sampler = sampler |
|
|
|
|
|
|
@@ -120,16 +128,26 @@ class DataLoader: |
|
|
|
"pyarrow.plasma does not support ParallelDataLoader on windows, changing num_workers to be zero" |
|
|
|
) |
|
|
|
self.num_workers = 0 |
|
|
|
if self.num_workers == 0: |
|
|
|
return _SerialDataLoaderIter(self) |
|
|
|
if isinstance(self.dataset, StreamDataset): |
|
|
|
if not self.num_workers: |
|
|
|
return _SerialStreamDataLoaderIter(self) |
|
|
|
else: |
|
|
|
return _ParallelStreamDataLoaderIter(self) |
|
|
|
elif isinstance(self.dataset, MapDataset): |
|
|
|
if not self.num_workers: |
|
|
|
return _SerialMapDataLoaderIter(self) |
|
|
|
else: |
|
|
|
return _ParallelMapDataLoaderIter(self) |
|
|
|
else: |
|
|
|
return _ParallelDataLoaderIter(self) |
|
|
|
raise TypeError( |
|
|
|
"can not recognize this kind of dataset: %s" % type(self.dataset) |
|
|
|
) |
|
|
|
|
|
|
|
def __len__(self):
    """Return the length reported by the loader's sampler."""
    return len(self.sampler)
|
|
|
|
|
|
|
|
|
|
|
class _BaseDataLoaderIter: |
|
|
|
class _BaseMapDataLoaderIter: |
|
|
|
def __init__(self, loader): |
|
|
|
self.dataset = loader.dataset |
|
|
|
self.sampler = loader.sampler |
|
|
@@ -158,9 +176,9 @@ class _BaseDataLoaderIter: |
|
|
|
return minibatch |
|
|
|
|
|
|
|
|
|
|
|
class _SerialDataLoaderIter(_BaseDataLoaderIter): |
|
|
|
class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter): |
|
|
|
def __init__(self, loader): |
|
|
|
super(_SerialDataLoaderIter, self).__init__(loader) |
|
|
|
super(_SerialMapDataLoaderIter, self).__init__(loader) |
|
|
|
self.indices_iter = iter(self.sampler) |
|
|
|
|
|
|
|
def _get_next_batch(self): |
|
|
@@ -170,11 +188,11 @@ class _SerialDataLoaderIter(_BaseDataLoaderIter): |
|
|
|
return self.collator.apply(trans_items) |
|
|
|
|
|
|
|
|
|
|
|
class _ParallelDataLoaderIter(_BaseDataLoaderIter): |
|
|
|
class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter): |
|
|
|
__initialized = False |
|
|
|
|
|
|
|
def __init__(self, loader): |
|
|
|
super(_ParallelDataLoaderIter, self).__init__(loader) |
|
|
|
super(_ParallelMapDataLoaderIter, self).__init__(loader) |
|
|
|
|
|
|
|
self.task_queues = [ |
|
|
|
multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers) |
|
|
@@ -326,6 +344,175 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter): |
|
|
|
self._shutdown() |
|
|
|
|
|
|
|
|
|
|
|
class _BaseStreamDataLoaderIter:
    """Shared state and iterator protocol for stream-dataset loader iterators.

    Subclasses implement :meth:`_get_next_batch`. Every batch it returns is
    passed through the dataset's ``post_process`` callable before being handed
    to the caller of ``next()``.
    """

    def __init__(self, loader):
        """Copy the configuration the iterator needs from ``loader``."""
        dataset = loader.dataset
        self.dataset = dataset
        self.sampler = loader.sampler
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        # Hook applied to every batch in ``__next__``.
        self.post_process = dataset.post_process

    def _get_next_batch(self):
        """Return the next batch; concrete iterators must override this."""
        raise NotImplementedError

    def __iter__(self):
        """The iterator is its own iterable."""
        return self

    def __next__(self):
        """Fetch the next batch and run it through ``post_process``."""
        finalize = self.post_process
        batch = self._get_next_batch()
        return finalize(batch)
|
|
|
|
|
|
|
|
|
|
|
class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    """Single-process iterator that builds batches from a stream dataset.

    Each item pulled from the dataset is a sequence of fields, where every
    field is indexable along the same leading axis (``len(item[0])`` samples).
    Samples are split out one by one, transformed, and collated once
    ``sampler.batch_size`` of them have been gathered.

    Samples of an item that do not fit into the current batch are kept and
    used to start the next batch instead of being discarded.
    """

    def __init__(self, loader):
        super().__init__(loader)
        self.dataset_iter = iter(self.dataset)
        # Raw (untransformed) samples left over from the last fetched item,
        # stored reversed so that ``pop()`` yields them in original order.
        self._unused = []

    def _get_next_batch(self):
        """Collect ``sampler.batch_size`` transformed samples and collate them.

        Raises:
            RuntimeError: if ``timeout`` is positive and more than ``timeout``
                seconds have elapsed when another item must be fetched.
            StopIteration: propagated from the dataset iterator when it is
                exhausted; a partially filled batch is dropped in that case.
        """
        batch_size = self.sampler.batch_size
        ret = []
        start_time = time.time()
        while len(ret) < batch_size:
            if self._unused:
                # Consume buffered samples before touching the dataset again.
                ret.append(self.transform.apply(self._unused.pop()))
                continue

            # The timeout is checked only before fetching a new item.
            waited_time = time.time() - start_time
            if self.timeout > 0 and waited_time > self.timeout:
                raise RuntimeError("get_next_batch timeout!")

            item = next(self.dataset_iter)
            samples = [
                tuple(field[idx] for field in item) for idx in range(len(item[0]))
            ]
            samples.reverse()
            self._unused = samples
        return self.collator.apply(ret)
|
|
|
|
|
|
|
|
|
|
|
class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    """Multi-process iterator over a stream dataset.

    ``num_workers`` data processes each iterate the dataset, transform every
    sample and push it onto their own queue. A single collator process drains
    those queues in round-robin order, collates ``sampler.batch_size`` samples
    and puts the resulting batch on a shared-memory queue, from which the
    consumer reads in :meth:`_get_next_batch`.
    """

    # Set to True on the instance once all processes have been started, so
    # ``__del__`` does not attempt to shut down a half-constructed iterator.
    __initialized = False

    def __init__(self, loader):
        super().__init__(loader)

        # Shared integer flag; workers stop looping once it becomes 1.
        self.shutdown_flag = multiprocessing.Value("i", 0)

        # shared-memory queue implemented by pyarrow plasma store
        from ._queue import PlasmaShmQueue

        self.batch_queue = PlasmaShmQueue(maxsize=2)
        self.workers = []
        # One single-slot queue per data worker.
        self.worker_queues = [
            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
        ]
        for worker_id in range(self.num_workers):
            worker = multiprocessing.Process(
                target=self._gen_data, args=(worker_id,), daemon=True
            )
            worker.start()
            self.workers.append(worker)
        self.collator_worker = multiprocessing.Process(
            target=self._gen_batch, daemon=True
        )
        self.collator_worker.start()

        self.__initialized = True

    def _gen_data(self, worker_id):
        """Data-worker loop: transform samples and feed ``worker_queues[worker_id]``.

        Runs until the shutdown flag is set. Each dataset item is split into
        ``len(item[0])`` samples, and each sample is transformed and queued.
        """
        worker_queue = self.worker_queues[worker_id]
        dataset_iter = iter(self.dataset)
        while True:
            if self.shutdown_flag.value == 1:
                break
            item = next(dataset_iter)
            for idx in range(len(item[0])):
                trans_item = self.transform.apply(tuple(e[idx] for e in item))
                while True:
                    try:
                        # A timeout is required: a plain blocking put() never
                        # raises queue.Full, so the shutdown check below would
                        # be unreachable while the queue stays full.
                        worker_queue.put(trans_item, timeout=1)
                        break
                    except queue.Full:
                        if self.shutdown_flag.value == 1:
                            # Shutting down: stop producing altogether.
                            return
                        logger.debug("batch part queue is full")

    def _gen_batch(self):
        """Collator loop: gather samples round-robin and emit collated batches.

        Runs until the shutdown flag is set. A worker queue that yields nothing
        within ``MP_QUEUE_GET_TIMEOUT`` is skipped for this round.
        """
        cnt = -1
        trans_items = []
        while True:
            if self.shutdown_flag.value == 1:
                break
            cnt += 1
            queue_id = cnt % self.num_workers
            try:
                trans_item = self.worker_queues[queue_id].get(
                    timeout=MP_QUEUE_GET_TIMEOUT
                )
            except queue.Empty:
                continue
            trans_items.append(trans_item)
            if len(trans_items) == self.sampler.batch_size:
                batch_data = self.collator.apply(trans_items)
                while True:
                    try:
                        self.batch_queue.put(batch_data, timeout=1)
                        break
                    except queue.Full:
                        if self.shutdown_flag.value == 1:
                            break
                        logger.debug("batch queue is full")
                trans_items = []

    def _check_workers(self):
        """Raise if the collator or any data worker exited with a non-zero code.

        Raises:
            RuntimeError: naming the dead process and its exit code.
        """
        if not self.collator_worker.is_alive():
            exitcode = self.collator_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("collator worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError(
                        "worker: {} died. {}".format(worker_id, exitcode)
                    )

    def _try_get_next_batch(self):
        """Poll the batch queue until a batch arrives.

        Worker health is checked before every poll.

        Raises:
            RuntimeError: if a worker died (see :meth:`_check_workers`), or if
                ``timeout`` is positive and has been exceeded.
        """
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0 and waited_time > self.timeout:
                raise RuntimeError("get_next_batch timeout!")

    def _get_next_batch(self):
        """Return the next collated batch produced by the collator process."""
        batch_data = self._try_get_next_batch()
        return batch_data

    def _shutdown(self):
        """Signal shutdown, then terminate the processes and close the queues."""
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.collator_worker.is_alive():
            self.collator_worker.terminate()
        self.collator_worker.join()

        for worker in self.workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.worker_queues:
            q.cancel_join_thread()
            q.close()

        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        # Only clean up if __init__ ran to completion.
        if self.__initialized:
            self._shutdown()
|
|
|
|
|
|
|
|
|
|
|
def _task_feeding_loop( |
|
|
|
indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx |
|
|
|
): |
|
|
|