@@ -12,6 +12,7 @@ import multiprocessing
+import platform
 import queue
 import random
 import threading
 import time
 
 import numpy as np
@@ -23,10 +24,16 @@ from .dataset import Dataset, StreamDataset
 from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
 from .transform import PseudoTransform, Transform
 
+try:
+    import thread
+except:
+    import _thread as thread
+
+
 logger = get_logger(__name__)
 
 
-MP_QUEUE_GET_TIMEOUT = 5
+GLOBAL_TIMEOUT = 5
 
 
 class DataLoader:
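Note on the new import block: `interrupt_main` lives in the `thread` module on Python 2 and in `_thread` on Python 3, hence the fallback. It lets a helper thread simulate a Ctrl-C in the main thread, which is what the serial stream iterator below uses to bound a blocking `next()`. The constant is also renamed from `MP_QUEUE_GET_TIMEOUT` to `GLOBAL_TIMEOUT` now that it governs more than multiprocessing-queue reads. A standalone sketch of the mechanism (illustrative, not part of the patch):

```python
import threading

try:
    import thread             # Python 2
except ImportError:
    import _thread as thread  # Python 3

# Arm a timer that simulates SIGINT in the main thread after 1 second.
timer = threading.Timer(1.0, thread.interrupt_main)
timer.start()
try:
    while True:  # stands in for a slow, pure-Python producer
        pass     # the KeyboardInterrupt lands at a bytecode boundary
except KeyboardInterrupt:
    print("interrupted by the timer")
finally:
    timer.cancel()
```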
@@ -39,7 +46,7 @@ class DataLoader:
         transform: Transform = None,
         collator: Collator = None,
         num_workers: int = 0,
-        timeout: int = 0,
+        timeout: int = GLOBAL_TIMEOUT,
         divide: bool = False,
     ):
         r"""
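This flips the default from "never time out" to the shared 5-second deadline. The elapsed-time checks in the iterators are written as `timeout > 0 and ...`, so a non-positive value still disables the wall-clock check. A minimal sketch of that guard (names are illustrative):

```python
GLOBAL_TIMEOUT = 5

def deadline_exceeded(timeout: int, waited: float) -> bool:
    # Mirrors the guard used below: a non-positive timeout disables the check.
    return timeout > 0 and waited > timeout

assert deadline_exceeded(GLOBAL_TIMEOUT, 6.0)
assert not deadline_exceeded(0, 1e9)
```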
@@ -377,21 +384,23 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
 
     def _get_next_batch(self):
         ret = []
+        start_time = time.time()
         while len(ret) != self.sampler.batch_size:
+            waited_time = time.time() - start_time
+            if self.timeout > 0 and waited_time > self.timeout:
+                raise RuntimeError("get_next_batch timeout!")
             if self.idx != 0:
                 data = self.data
             else:
                 try:
+                    timer = threading.Timer(self.timeout, thread.interrupt_main)
+                    timer.start()
                     raw_data = next(self.dataset_iter)
+                    timer.cancel()
+                except KeyboardInterrupt:
+                    raise RuntimeError("get_next_batch timeout!")
                 except:
+                    timer.cancel()
                     continue
                 assert len(raw_data) == 2 and isinstance(
                     raw_data[0], bool
-                ), "raw_data must be a tuple"
+                ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
                 if not raw_data[0]:
                     data = list((x,) for x in raw_data[1])
                 else:
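The serial path cannot use `queue.get(timeout=...)` because it consumes `next(self.dataset_iter)` directly, so the patch arms a `threading.Timer` that fires `thread.interrupt_main()`, turning a stuck `next()` into a `KeyboardInterrupt` that is rethrown as the timeout error. Two caveats worth knowing: this only works from the main thread, and the simulated interrupt is delivered at a bytecode boundary, so it bounds Python-level work rather than a call blocked in C. Note also that the timer is armed unconditionally, so this path effectively assumes a positive `timeout`. The same pattern as a standalone helper (illustrative names, not the patch's code):

```python
import threading
import _thread as thread

def next_with_deadline(it, timeout):
    """Return next(it), raising RuntimeError if it runs longer than `timeout`.

    Must be called from the main thread, since interrupt_main targets it.
    """
    timer = threading.Timer(timeout, thread.interrupt_main)
    timer.start()
    try:
        item = next(it)
        timer.cancel()
        return item
    except KeyboardInterrupt:
        raise RuntimeError("get_next_batch timeout!")
    except BaseException:
        timer.cancel()  # disarm before propagating StopIteration and friends
        raise
```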
@@ -456,7 +465,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
             raw_data = next(dataset_iter)
             assert len(raw_data) == 2 and isinstance(
                 raw_data[0], bool
-            ), "raw_data must be a tuple"
+            ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
             if not raw_data[0]:
                 data = list((x,) for x in raw_data[1])
             else:
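Both assertion sites now spell out the `StreamDataset` contract: each item is a 2-tuple `(batched, fields)`, where `batched` says whether `fields` already holds whole batches; unbatched fields get wrapped into per-sample tuples by the `list((x,) for x in raw_data[1])` line. A toy dataset honoring that contract (a sketch only, skipping the real `StreamDataset` base class):

```python
import numpy as np

class ToyStream:
    """Illustrative stream: yields (batched, fields) pairs."""

    def __iter__(self):
        while True:
            # One sample at a time: fields are (image, label).
            image = np.random.rand(3, 32, 32).astype("float32")
            label = np.random.randint(10)
            yield (False, (image, label))
            # Or hand over a ready-made batch of 8 in one item.
            images = np.random.rand(8, 3, 32, 32).astype("float32")
            labels = np.random.randint(10, size=8)
            yield (True, (images, labels))
```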
@@ -478,7 +487,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
             if self.shutdown_flag.value == 1:
                 break
             try:
-                data = self.raw_data_queues[worker_id].get(timeout=MP_QUEUE_GET_TIMEOUT)
+                data = self.raw_data_queues[worker_id].get(timeout=GLOBAL_TIMEOUT)
             except queue.Empty:
                 continue
             trans_data = self.transform.apply(data)
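This and the remaining hunks are the mechanical half of the rename: every queue consumer polls in `GLOBAL_TIMEOUT`-second slices instead of blocking indefinitely, so a raised shutdown flag is noticed within one slice. The shape, extracted (an illustrative helper, not in the patch):

```python
import queue

GLOBAL_TIMEOUT = 5

def drain(q, shutdown_flag):
    """Yield items from q until shutdown_flag is set."""
    while True:
        if shutdown_flag.value == 1:
            break
        try:
            yield q.get(timeout=GLOBAL_TIMEOUT)
        except queue.Empty:
            continue  # nothing yet; re-check the flag and poll again
```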
@@ -501,7 +510,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
             queue_id = cnt % self.num_workers
             try:
                 trans_item = self.trans_data_queues[queue_id].get(
-                    timeout=MP_QUEUE_GET_TIMEOUT
+                    timeout=GLOBAL_TIMEOUT
                 )
             except queue.Empty:
                 continue
@@ -622,7 +631,7 @@ def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdow
         if shutdown_flag.value == 1:
             break
         try:
-            batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT)
+            batch_idx, indices = task_queue.get(timeout=GLOBAL_TIMEOUT)
         except queue.Empty:
             continue
         if len(indices) > 0:
@@ -665,7 +674,7 @@ def _data_gathering_loop(
         while True:
             try:
                 batch_idx, trans_items = trans_data_queues[worker_id].get(
-                    timeout=MP_QUEUE_GET_TIMEOUT
+                    timeout=GLOBAL_TIMEOUT
                 )
                 break
             except queue.Empty:
@@ -726,7 +735,7 @@ def _data_selecting_loop(
     while True:
         try:
             batch_idx, trans_items = trans_data_queues[target_worker_id].get(
-                timeout=MP_QUEUE_GET_TIMEOUT
+                timeout=GLOBAL_TIMEOUT
            )
             batch_data = collator.apply(trans_items)
             break
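Taken together: map-style workers, the gathering loop, and the selecting loop all observe the same deadline constant, while the stream loaders additionally bound the dataset's own `next()`. Assuming these are MegEngine's data-loading internals (which the relative imports suggest), usage would look roughly like the following; treat the names and signatures as illustrative:

```python
# Hypothetical end-to-end use; ToyStream is the sketch above.
from megengine.data import DataLoader
from megengine.data.sampler import StreamSampler

loader = DataLoader(
    ToyStream(),
    sampler=StreamSampler(batch_size=8),
    num_workers=2,
    timeout=10,  # override the new GLOBAL_TIMEOUT default of 5 seconds
)
for images, labels in loader:
    break  # a stream loader never exhausts on its own
```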