
fix(mge/data): support timeout for serial stream dataloader

GitOrigin-RevId: 1ae5a8cfda
tags/v1.3.0
Megvii Engine Team 4 years ago
parent commit 7191c4bd9f
2 changed files with 29 additions and 26 deletions
  1. +22 -13  imperative/python/megengine/data/dataloader.py
  2. +7  -13  imperative/python/test/unit/data/test_dataloader.py

+22 -13  imperative/python/megengine/data/dataloader.py

@@ -12,6 +12,7 @@ import multiprocessing
 import platform
 import queue
 import random
+import threading
 import time

 import numpy as np
@@ -23,10 +24,16 @@ from .dataset import Dataset, StreamDataset
 from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
 from .transform import PseudoTransform, Transform

+try:
+    import thread
+except:
+    import _thread as thread
+
+
 logger = get_logger(__name__)


-MP_QUEUE_GET_TIMEOUT = 5
+GLOBAL_TIMEOUT = 5


 class DataLoader:
@@ -39,7 +46,7 @@ class DataLoader:
         transform: Transform = None,
         collator: Collator = None,
         num_workers: int = 0,
-        timeout: int = 0,
+        timeout: int = GLOBAL_TIMEOUT,
         divide: bool = False,
     ):
         r"""
@@ -377,21 +384,23 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):

     def _get_next_batch(self):
         ret = []
-        start_time = time.time()
         while len(ret) != self.sampler.batch_size:
-            waited_time = time.time() - start_time
-            if self.timeout > 0 and waited_time > self.timeout:
-                raise RuntimeError("get_next_batch timeout!")
             if self.idx != 0:
                 data = self.data
             else:
                 try:
+                    timer = threading.Timer(self.timeout, thread.interrupt_main)
+                    timer.start()
                     raw_data = next(self.dataset_iter)
+                    timer.cancel()
+                except KeyboardInterrupt:
+                    raise RuntimeError("get_next_batch timeout!")
                 except:
+                    timer.cancel()
                     continue
                 assert len(raw_data) == 2 and isinstance(
                     raw_data[0], bool
-                ), "raw_data must be a tuple"
+                ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
                 if not raw_data[0]:
                     data = list((x,) for x in raw_data[1])
                 else:
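
The hunk above is the core of the fix: instead of comparing wall-clock time between loop iterations (which never fires while next() itself is blocked), a threading.Timer calls thread.interrupt_main() once the timeout elapses, the resulting KeyboardInterrupt unblocks the main thread, and it is re-raised as a RuntimeError. A minimal standalone sketch of the same pattern; blocking_source and next_with_timeout are illustrative names, not part of the MegEngine code:

import threading
import time

try:
    import thread
except ImportError:
    import _thread as thread


def blocking_source():
    # Stand-in for a StreamDataset iterator that stalls before yielding.
    # Short sleeps in a loop: interrupt_main() only raises KeyboardInterrupt
    # at the next bytecode check, so a single long sleep would not be cut short.
    while True:
        for _ in range(10):
            time.sleep(1)
        yield 42


def next_with_timeout(iterator, timeout):
    # Arm a watchdog that interrupts the main thread if next() blocks too long.
    timer = threading.Timer(timeout, thread.interrupt_main)
    timer.start()
    try:
        return next(iterator)
    except KeyboardInterrupt:
        # The watchdog fired: report it the same way the dataloader does.
        raise RuntimeError("get_next_batch timeout!")
    finally:
        timer.cancel()  # no-op if the timer already fired


if __name__ == "__main__":
    try:
        next_with_timeout(blocking_source(), timeout=2)
    except RuntimeError as exc:
        print(exc)  # get_next_batch timeout!

The same bytecode-boundary limitation is why the updated test below blocks with ten one-second sleeps rather than one ten-second sleep.
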
@@ -456,7 +465,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
             raw_data = next(dataset_iter)
             assert len(raw_data) == 2 and isinstance(
                 raw_data[0], bool
-            ), "raw_data must be a tuple"
+            ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
             if not raw_data[0]:
                 data = list((x,) for x in raw_data[1])
             else:
@@ -478,7 +487,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
             if self.shutdown_flag.value == 1:
                 break
             try:
-                data = self.raw_data_queues[worker_id].get(timeout=MP_QUEUE_GET_TIMEOUT)
+                data = self.raw_data_queues[worker_id].get(timeout=GLOBAL_TIMEOUT)
             except queue.Empty:
                 continue
             trans_data = self.transform.apply(data)
@@ -501,7 +510,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
             queue_id = cnt % self.num_workers
             try:
                 trans_item = self.trans_data_queues[queue_id].get(
-                    timeout=MP_QUEUE_GET_TIMEOUT
+                    timeout=GLOBAL_TIMEOUT
                 )
             except queue.Empty:
                 continue
@@ -622,7 +631,7 @@ def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdow
         if shutdown_flag.value == 1:
             break
         try:
-            batch_idx, indices = task_queue.get(timeout=MP_QUEUE_GET_TIMEOUT)
+            batch_idx, indices = task_queue.get(timeout=GLOBAL_TIMEOUT)
         except queue.Empty:
             continue
         if len(indices) > 0:
@@ -665,7 +674,7 @@ def _data_gathering_loop(
         while True:
             try:
                 batch_idx, trans_items = trans_data_queues[worker_id].get(
-                    timeout=MP_QUEUE_GET_TIMEOUT
+                    timeout=GLOBAL_TIMEOUT
                 )
                 break
             except queue.Empty:
@@ -726,7 +735,7 @@ def _data_selecting_loop(
         while True:
             try:
                 batch_idx, trans_items = trans_data_queues[target_worker_id].get(
-                    timeout=MP_QUEUE_GET_TIMEOUT
+                    timeout=GLOBAL_TIMEOUT
                 )
                 batch_data = collator.apply(trans_items)
                 break
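
The remaining hunks in this file are a rename: the multiprocessing paths already bound every queue.get with a timeout so that workers can periodically re-check the shutdown flag, and the constant MP_QUEUE_GET_TIMEOUT simply becomes the shared GLOBAL_TIMEOUT. A small self-contained sketch of that polling loop, with illustrative names (worker_loop, tasks, shutdown):

import multiprocessing
import queue
import time

GLOBAL_TIMEOUT = 5  # shared upper bound for every queue.get, as in the diff


def worker_loop(task_queue, shutdown_flag):
    # Bound each get() so the worker re-checks the shutdown flag periodically
    # instead of blocking forever on an empty queue.
    while True:
        if shutdown_flag.value == 1:
            break
        try:
            task = task_queue.get(timeout=GLOBAL_TIMEOUT)
        except queue.Empty:
            continue
        print("processing", task)


if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    tasks = ctx.Queue()
    shutdown = ctx.Value("i", 0)
    worker = ctx.Process(target=worker_loop, args=(tasks, shutdown))
    worker.start()
    tasks.put("batch-0")
    time.sleep(1)       # let the worker drain the queue
    shutdown.value = 1  # worker exits after at most one GLOBAL_TIMEOUT wait
    worker.join()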


+7 -13  imperative/python/test/unit/data/test_dataloader.py

@@ -61,13 +61,17 @@ def test_dataloader_init():


 class MyStream(StreamDataset):
-    def __init__(self, number, batch=False, error=False):
+    def __init__(self, number, batch=False, error=False, block=False):
         self.number = number
         self.batch = batch
         self.error = error
+        self.block = block

     def __iter__(self):
         for cnt in range(self.number):
+            if self.block:
+                for _ in range(10):
+                    time.sleep(1)
             if self.batch:
                 data = np.random.randint(0, 256, (2, 32, 32, 3), dtype="uint8")
                 yield (True, (data, [cnt, cnt - self.number]))
@@ -115,20 +119,10 @@ def test_stream_dataloader_error():

 @pytest.mark.parametrize("num_workers", [0, 2])
 def test_stream_dataloader_timeout(num_workers):
-    dataset = MyStream(100, False)
+    dataset = MyStream(100, False, block=True)
     sampler = StreamSampler(batch_size=4)

-    class TimeoutTransform(Transform):
-        def __init__(self):
-            pass
-
-        def apply(self, input):
-            time.sleep(10)
-            return input
-
-    dataloader = DataLoader(
-        dataset, sampler, TimeoutTransform(), num_workers=num_workers, timeout=5
-    )
+    dataloader = DataLoader(dataset, sampler, num_workers=num_workers, timeout=5)
     with pytest.raises(RuntimeError, match=r".*timeout.*"):
         data_iter = iter(dataloader)
         next(data_iter)
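
For reference, a minimal usage sketch of the behavior this test exercises, assuming the megengine.data modules touched by this commit; BlockingStream is an illustrative dataset, and the non-batched (False, sample) tuple layout follows the assertion message added above:

import time

import numpy as np

from megengine.data.dataloader import DataLoader
from megengine.data.dataset import StreamDataset
from megengine.data.sampler import StreamSampler


class BlockingStream(StreamDataset):
    # Hypothetical dataset that stalls far longer than the configured timeout.
    def __init__(self):
        pass

    def __iter__(self):
        while True:
            for _ in range(10):
                time.sleep(1)
            data = np.random.randint(0, 256, (32, 32, 3), dtype="uint8")
            # First item of the tuple says whether the data is already batched.
            yield (False, (data, 0))


dataset = BlockingStream()
sampler = StreamSampler(batch_size=4)
# num_workers=0 takes the serial path patched above; timeout=5 bounds each fetch.
dataloader = DataLoader(dataset, sampler, num_workers=0, timeout=5)

try:
    next(iter(dataloader))
except RuntimeError as exc:
    print(exc)  # expected to mention a timeout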

