Browse Source

fix(mge/data): add timeout event

GitOrigin-RevId: 43f2ba1456
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
c418d3cd95
3 changed files with 116 additions and 78 deletions
  1. +89
    -68
      imperative/python/megengine/data/dataloader.py
  2. +1
    -1
      imperative/python/megengine/data/dataset/meta_dataset.py
  3. +26
    -9
      imperative/python/test/unit/data/test_dataloader.py

+ 89
- 68
imperative/python/megengine/data/dataloader.py View File

@@ -14,6 +14,7 @@ import queue
import random import random
import threading import threading
import time import time
from typing import Callable


import numpy as np import numpy as np


@@ -36,6 +37,10 @@ logger = get_logger(__name__)
GLOBAL_TIMEOUT = 5 GLOBAL_TIMEOUT = 5




def raise_timeout_error():
raise RuntimeError("dataloader timeout")


class DataLoader: class DataLoader:
__initialized = False __initialized = False


@@ -46,7 +51,8 @@ class DataLoader:
transform: Transform = None, transform: Transform = None,
collator: Collator = None, collator: Collator = None,
num_workers: int = 0, num_workers: int = 0,
timeout: int = GLOBAL_TIMEOUT,
timeout: int = 0,
timeout_event: Callable = raise_timeout_error,
divide: bool = False, divide: bool = False,
): ):
r""" r"""
@@ -71,6 +77,9 @@ class DataLoader:
:type timeout: int :type timeout: int
:param timeout: if positive, means the timeout value(second) for collecting a :param timeout: if positive, means the timeout value(second) for collecting a
batch from workers. Default: 0 batch from workers. Default: 0
:type timeout_event: Callable
:param timeout_event: callback function triggered by timeout, default to raise
runtime error.
:type divide: bool :type divide: bool
:param divide: define the paralleling strategy in multi-processing mode. :param divide: define the paralleling strategy in multi-processing mode.
``True`` means one batch is divided into :attr:`num_workers` pieces, and ``True`` means one batch is divided into :attr:`num_workers` pieces, and
@@ -92,6 +101,7 @@ class DataLoader:


self.num_workers = num_workers self.num_workers = num_workers
self.timeout = timeout self.timeout = timeout
self.timeout_event = timeout_event


self.divide = divide self.divide = divide


@@ -168,6 +178,7 @@ class _BaseMapDataLoaderIter:
self.collator = loader.collator self.collator = loader.collator
self.num_workers = loader.num_workers self.num_workers = loader.num_workers
self.timeout = loader.timeout self.timeout = loader.timeout
self.timeout_event = loader.timeout_event
self.divide = loader.divide self.divide = loader.divide
self.num_processed = 0 self.num_processed = 0


@@ -306,7 +317,7 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):


logger.debug("all workers are alive.") logger.debug("all workers are alive.")


def _try_get_next_batch(self):
def _get_next_batch(self):
start_time = time.time() start_time = time.time()
while True: while True:
self._check_workers() self._check_workers()
@@ -319,10 +330,6 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
if waited_time > self.timeout: if waited_time > self.timeout:
raise RuntimeError("get_next_batch timeout!") raise RuntimeError("get_next_batch timeout!")


def _get_next_batch(self):
batch_data = self._try_get_next_batch()
return batch_data

def _shutdown(self): def _shutdown(self):
with self.shutdown_flag.get_lock(): with self.shutdown_flag.get_lock():
self.shutdown_flag.value = 1 self.shutdown_flag.value = 1
@@ -364,10 +371,24 @@ class _BaseStreamDataLoaderIter:
self.collator = loader.collator self.collator = loader.collator
self.num_workers = loader.num_workers self.num_workers = loader.num_workers
self.timeout = loader.timeout self.timeout = loader.timeout
self.timeout_event = loader.timeout_event


def _get_next_batch(self): def _get_next_batch(self):
raise NotImplementedError raise NotImplementedError


def _process_raw_data(self, raw_data):
assert len(raw_data) == 2 and isinstance(
raw_data[0], bool
), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
if not raw_data[0]:
data = list((x,) for x in raw_data[1])
else:
data = raw_data[1]
ret = []
for idx in range(len(data[0])):
ret.append(tuple(e[idx] for e in data))
return ret

def __iter__(self): def __iter__(self):
return self return self


@@ -380,42 +401,43 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
super().__init__(loader) super().__init__(loader)
self.dataset_iter = iter(self.dataset) self.dataset_iter = iter(self.dataset)
self.idx = 0 self.idx = 0
self.data = None
self.unused = []


def _get_next_batch(self):
ret = []
while len(ret) != self.sampler.batch_size:
if self.idx != 0:
data = self.data
else:
try:
def _try_get_raw_data(self, start_time):
raw_data = None
while not raw_data:
try:
if self.timeout > 0:
timer = threading.Timer(self.timeout, thread.interrupt_main) timer = threading.Timer(self.timeout, thread.interrupt_main)
timer.start() timer.start()
raw_data = next(self.dataset_iter)
raw_data = next(self.dataset_iter)
if self.timeout > 0:
timer.cancel() timer.cancel()
except KeyboardInterrupt:
raise RuntimeError("get_next_batch timeout!")
except:
except KeyboardInterrupt:
raw_data = self.timeout_event()
except:
if self.timeout > 0:
timer.cancel() timer.cancel()
continue
assert len(raw_data) == 2 and isinstance(
raw_data[0], bool
), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
if not raw_data[0]:
data = list((x,) for x in raw_data[1])
else:
data = raw_data[1]
for idx in range(self.idx, len(data[0])):
trans_data = self.transform.apply(tuple(e[idx] for e in data))
ret.append(trans_data)
if len(ret) == self.sampler.batch_size:
if idx + 1 == len(data[0]):
self.idx = 0
self.data = None
else:
self.idx = idx
self.data = data
break
waited_time = time.time() - start_time
if waited_time > self.timeout:
raw_data = self.timeout_event()
return raw_data

def _get_next_batch(self):
ret = []
start_time = time.time()
while len(ret) < self.sampler.batch_size:
if len(self.unused) != 0:
batch_data = self.unused
else:
raw_data = self._try_get_raw_data(start_time)
batch_data = self._process_raw_data(raw_data)

while len(batch_data) != 0 and len(ret) < self.sampler.batch_size:
data = batch_data.pop()
ret.append(self.transform.apply(data))
self.unused = batch_data

return self.collator.apply(ret) return self.collator.apply(ret)




@@ -440,49 +462,52 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):


self.batch_queue = PlasmaShmQueue(maxsize=2) self.batch_queue = PlasmaShmQueue(maxsize=2)


self.recieve_worker = multiprocessing.Process(target=self._recieve, daemon=True)
self.recieve_worker = multiprocessing.Process(
target=self._worker_to_raw_data_queues, daemon=True
)
self.recieve_worker.start() self.recieve_worker.start()


self.transform_workers = [] self.transform_workers = []
for worker_id in range(self.num_workers): for worker_id in range(self.num_workers):
worker = multiprocessing.Process( worker = multiprocessing.Process(
target=self._transform, args=(worker_id,), daemon=True
target=self._worker_to_trans_data_queues, args=(worker_id,), daemon=True
) )
worker.start() worker.start()
self.transform_workers.append(worker) self.transform_workers.append(worker)


self.collect_worker = multiprocessing.Process(target=self._collect, daemon=True)
self.collect_worker = multiprocessing.Process(
target=self._worker_to_batch_queue, daemon=True
)
self.collect_worker.start() self.collect_worker.start()


self.__initialized = True self.__initialized = True


def _recieve(self):
def _put_raw_data_queues(self, raw_data, qidx):
batch_data = self._process_raw_data(raw_data)
for data in batch_data:
while True:
qidx = qidx % self.num_workers
try:
self.raw_data_queues[qidx].put(data)
break
except queue.Full:
if self.shutdown_flag.value == 1:
break
logger.debug("raw data queue %d is full" % qidx)
finally:
qidx += 1
return qidx

def _worker_to_raw_data_queues(self):
dataset_iter = iter(self.dataset) dataset_iter = iter(self.dataset)
cnt = -1
qidx = 0
while True: while True:
if self.shutdown_flag.value == 1: if self.shutdown_flag.value == 1:
break break
raw_data = next(dataset_iter) raw_data = next(dataset_iter)
assert len(raw_data) == 2 and isinstance(
raw_data[0], bool
), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
if not raw_data[0]:
data = list((x,) for x in raw_data[1])
else:
data = raw_data[1]
for idx in range(len(data[0])):
while True:
cnt += 1
qid = cnt % self.num_workers
try:
self.raw_data_queues[qid].put(tuple(e[idx] for e in data))
break
except queue.Full:
if self.shutdown_flag.value == 1:
break
logger.debug("raw data queue is full")
qidx = self._put_raw_data_queues(raw_data, qidx)


def _transform(self, worker_id):
def _worker_to_trans_data_queues(self, worker_id):
while True: while True:
if self.shutdown_flag.value == 1: if self.shutdown_flag.value == 1:
break break
@@ -500,7 +525,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
break break
logger.debug("batch queue if full") logger.debug("batch queue if full")


def _collect(self):
def _worker_to_batch_queue(self):
cnt = -1 cnt = -1
trans_items = [] trans_items = []
while True: while True:
@@ -541,7 +566,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
"worker: {} died. {}".format(worker_id, exitcode) "worker: {} died. {}".format(worker_id, exitcode)
) )


def _try_get_next_batch(self):
def _get_next_batch(self):
start_time = time.time() start_time = time.time()
while True: while True:
self._check_workers() self._check_workers()
@@ -551,11 +576,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
logger.debug("batch queue empty!") logger.debug("batch queue empty!")
waited_time = time.time() - start_time waited_time = time.time() - start_time
if self.timeout > 0 and waited_time > self.timeout: if self.timeout > 0 and waited_time > self.timeout:
raise RuntimeError("get_next_batch timeout!")

def _get_next_batch(self):
batch_data = self._try_get_next_batch()
return batch_data
self._put_raw_data_queues(self.timeout_event(), 0)


def _shutdown(self): def _shutdown(self):
with self.shutdown_flag.get_lock(): with self.shutdown_flag.get_lock():


+ 1
- 1
imperative/python/megengine/data/dataset/meta_dataset.py View File

@@ -43,7 +43,7 @@ class StreamDataset(Dataset):
def __iter__(self): def __iter__(self):
pass pass


def __getitem__(self):
def __getitem__(self, idx):
raise AssertionError("can not get item from StreamDataset by index") raise AssertionError("can not get item from StreamDataset by index")


def __len__(self): def __len__(self):


+ 26
- 9
imperative/python/test/unit/data/test_dataloader.py View File

@@ -61,10 +61,10 @@ def test_dataloader_init():




class MyStream(StreamDataset): class MyStream(StreamDataset):
def __init__(self, number, batch=False, error=False, block=False):
def __init__(self, number, batch=False, error_foramt=False, block=False):
self.number = number self.number = number
self.batch = batch self.batch = batch
self.error = error
self.error_format = error_foramt
self.block = block self.block = block


def __iter__(self): def __iter__(self):
@@ -73,11 +73,11 @@ class MyStream(StreamDataset):
for _ in range(10): for _ in range(10):
time.sleep(1) time.sleep(1)
if self.batch: if self.batch:
data = np.random.randint(0, 256, (2, 32, 32, 3), dtype="uint8")
data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8")
yield (True, (data, [cnt, cnt - self.number])) yield (True, (data, [cnt, cnt - self.number]))
else: else:
data = np.random.randint(0, 256, (32, 32, 3), dtype="uint8")
if self.error:
data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
if self.error_format:
yield (data, cnt) yield (data, cnt)
else: else:
yield (False, (data, cnt)) yield (False, (data, cnt))
@@ -87,7 +87,7 @@ class MyStream(StreamDataset):
@pytest.mark.parametrize("batch", [True, False]) @pytest.mark.parametrize("batch", [True, False])
@pytest.mark.parametrize("num_workers", [0, 2]) @pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(batch, num_workers): def test_stream_dataloader(batch, num_workers):
dataset = MyStream(100, batch)
dataset = MyStream(100, batch=batch)
sampler = StreamSampler(batch_size=4) sampler = StreamSampler(batch_size=4)
dataloader = DataLoader( dataloader = DataLoader(
dataset, dataset,
@@ -101,7 +101,7 @@ def test_stream_dataloader(batch, num_workers):
for step, data in enumerate(dataloader): for step, data in enumerate(dataloader):
if step == 10: if step == 10:
break break
assert data[0].shape == (4, 3, 32, 32)
assert data[0].shape == (4, 3, 2, 2)
assert data[1].shape == (4,) assert data[1].shape == (4,)
for i in data[1]: for i in data[1]:
assert i not in check_set assert i not in check_set
@@ -109,7 +109,7 @@ def test_stream_dataloader(batch, num_workers):




def test_stream_dataloader_error(): def test_stream_dataloader_error():
dataset = MyStream(100, error=True)
dataset = MyStream(100, error_foramt=True)
sampler = StreamSampler(batch_size=4) sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(dataset, sampler) dataloader = DataLoader(dataset, sampler)
with pytest.raises(AssertionError, match=r".*tuple.*"): with pytest.raises(AssertionError, match=r".*tuple.*"):
@@ -122,7 +122,7 @@ def test_stream_dataloader_timeout(num_workers):
dataset = MyStream(100, False, block=True) dataset = MyStream(100, False, block=True)
sampler = StreamSampler(batch_size=4) sampler = StreamSampler(batch_size=4)


dataloader = DataLoader(dataset, sampler, num_workers=num_workers, timeout=5)
dataloader = DataLoader(dataset, sampler, num_workers=num_workers, timeout=2)
with pytest.raises(RuntimeError, match=r".*timeout.*"): with pytest.raises(RuntimeError, match=r".*timeout.*"):
data_iter = iter(dataloader) data_iter = iter(dataloader)
next(data_iter) next(data_iter)
@@ -264,3 +264,20 @@ def test_dataloader_parallel_multi_instances_multiprocessing():


for p in processes: for p in processes:
p.join() p.join()


@pytest.mark.parametrize("num_workers", [0, 2])
def test_timeout_event(num_workers):
def cb():
return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,))))

dataset = MyStream(100, block=True)
sampler = StreamSampler(batch_size=4)

dataloader = DataLoader(
dataset, sampler, num_workers=num_workers, timeout=2, timeout_event=cb
)
for _, data in enumerate(dataloader):
np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3)))
np.testing.assert_equal(data[1], np.ones(shape=(4,)))
break

Loading…
Cancel
Save