Browse Source

fix(mge/data): support timeout for serial stream dataloader

GitOrigin-RevId: 1ae5a8cfda
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
7191c4bd9f
2 changed files with 29 additions and 26 deletions
  1. +22
    -13
      imperative/python/megengine/data/dataloader.py
  2. +7
    -13
      imperative/python/test/unit/data/test_dataloader.py

+ 22
- 13
imperative/python/megengine/data/dataloader.py View File

@@ -12,6 +12,7 @@ import multiprocessing
import platform
import queue
import random
import threading
import time

import numpy as np
@@ -23,10 +24,16 @@ from .dataset import Dataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform

try:
import thread
except:
import _thread as thread


logger = get_logger(__name__)


MP_QUEUE_GET_TIMEOUT = 5
GLOBAL_TIMEOUT = 5


class DataLoader:
@@ -39,7 +46,7 @@ class DataLoader:
transform: Transform = None,
collator: Collator = None,
num_workers: int = 0,
timeout: int = 0,
timeout: int = GLOBAL_TIMEOUT,
divide: bool = False,
):
r"""
@@ -377,21 +384,23 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):

def _get_next_batch(self):
ret = []
start_time = time.time()
while len(ret) != self.sampler.batch_size:
waited_time = time.time() - start_time
if self.timeout > 0 and waited_time > self.timeout:
raise RuntimeError("get_next_batch timeout!")
if self.idx != 0:
data = self.data
else:
try:
timer = threading.Timer(self.timeout, thread.interrupt_main)
timer.start()
raw_data = next(self.dataset_iter)
timer.cancel()
except KeyboardInterrupt:
raise RuntimeError("get_next_batch timeout!")
except:
timer.cancel()
continue
assert len(raw_data) == 2 and isinstance(
raw_data[0], bool
), "raw_data must be a tuple"
), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
if not raw_data[0]:
data = list((x,) for x in raw_data[1])
else:
@@ -456,7 +465,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
raw_data = next(dataset_iter)
assert len(raw_data) == 2 and isinstance(
raw_data[0], bool
), "raw_data must be a tuple"
), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
if not raw_data[0]:
data = list((x,) for x in raw_data[1])
else:
@@ -478,7 +487,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
if self.shutdown_flag.value == 1:
break
try:
data = self.raw_data_queues[worker_id].get(timeout=MP_QUEUE_GET_TIMEOUT)
data = self.raw_data_queues[worker_id].get(timeout=GLOBAL_TIMEOUT)
except queue.Empty:
continue
trans_data = self.transform.apply(data)
@@ -501,7 +510,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
queue_id = cnt % self.num_workers
try:
trans_item = self.trans_data_queues[queue_id].get(
timeout=MP_QUEUE_GET_TIMEOUT
timeout=GLOBAL_TIMEOUT
)
except queue.Empty:
continue
@@ -622,7 +631,7 @@ def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdow
if shutdown_flag.value == 1:
break
try:
batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT)
batch_idx, indices = task_queue.get(timeout=GLOBAL_TIMEOUT)
except queue.Empty:
continue
if len(indices) > 0:
@@ -665,7 +674,7 @@ def _data_gathering_loop(
while True:
try:
batch_idx, trans_items = trans_data_queues[worker_id].get(
timeout=MP_QUEUE_GET_TIMEOUT
timeout=GLOBAL_TIMEOUT
)
break
except queue.Empty:
@@ -726,7 +735,7 @@ def _data_selecting_loop(
while True:
try:
batch_idx, trans_items = trans_data_queues[target_worker_id].get(
timeout=MP_QUEUE_GET_TIMEOUT
timeout=GLOBAL_TIMEOUT
)
batch_data = collator.apply(trans_items)
break


+ 7
- 13
imperative/python/test/unit/data/test_dataloader.py View File

@@ -61,13 +61,17 @@ def test_dataloader_init():


class MyStream(StreamDataset):
def __init__(self, number, batch=False, error=False):
def __init__(self, number, batch=False, error=False, block=False):
self.number = number
self.batch = batch
self.error = error
self.block = block

def __iter__(self):
for cnt in range(self.number):
if self.block:
for _ in range(10):
time.sleep(1)
if self.batch:
data = np.random.randint(0, 256, (2, 32, 32, 3), dtype="uint8")
yield (True, (data, [cnt, cnt - self.number]))
@@ -115,20 +119,10 @@ def test_stream_dataloader_error():

@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader_timeout(num_workers):
dataset = MyStream(100, False)
dataset = MyStream(100, False, block=True)
sampler = StreamSampler(batch_size=4)

class TimeoutTransform(Transform):
def __init__(self):
pass

def apply(self, input):
time.sleep(10)
return input

dataloader = DataLoader(
dataset, sampler, TimeoutTransform(), num_workers=num_workers, timeout=5
)
dataloader = DataLoader(dataset, sampler, num_workers=num_workers, timeout=5)
with pytest.raises(RuntimeError, match=r".*timeout.*"):
data_iter = iter(dataloader)
next(data_iter)


Loading…
Cancel
Save