diff --git a/imperative/python/megengine/data/dataloader.py b/imperative/python/megengine/data/dataloader.py index ae711326..4c4a8368 100644 --- a/imperative/python/megengine/data/dataloader.py +++ b/imperative/python/megengine/data/dataloader.py @@ -14,6 +14,7 @@ import queue import random import threading import time +from typing import Callable import numpy as np @@ -36,6 +37,10 @@ logger = get_logger(__name__) GLOBAL_TIMEOUT = 5 +def raise_timeout_error(): + raise RuntimeError("dataloader timeout") + + class DataLoader: __initialized = False @@ -46,7 +51,8 @@ class DataLoader: transform: Transform = None, collator: Collator = None, num_workers: int = 0, - timeout: int = GLOBAL_TIMEOUT, + timeout: int = 0, + timeout_event: Callable = raise_timeout_error, divide: bool = False, ): r""" @@ -71,6 +77,9 @@ class DataLoader: :type timeout: int :param timeout: if positive, means the timeout value(second) for collecting a batch from workers. Default: 0 + :type timeout_event: Callable + :param timeout_event: callback function triggered by timeout, default to raise + runtime error. :type divide: bool :param divide: define the paralleling strategy in multi-processing mode. ``True`` means one batch is divided into :attr:`num_workers` pieces, and @@ -92,6 +101,7 @@ class DataLoader: self.num_workers = num_workers self.timeout = timeout + self.timeout_event = timeout_event self.divide = divide @@ -168,6 +178,7 @@ class _BaseMapDataLoaderIter: self.collator = loader.collator self.num_workers = loader.num_workers self.timeout = loader.timeout + self.timeout_event = loader.timeout_event self.divide = loader.divide self.num_processed = 0 @@ -306,7 +317,7 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter): logger.debug("all workers are alive.") - def _try_get_next_batch(self): + def _get_next_batch(self): start_time = time.time() while True: self._check_workers() @@ -319,10 +330,6 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter): if waited_time > self.timeout: raise RuntimeError("get_next_batch timeout!") - def _get_next_batch(self): - batch_data = self._try_get_next_batch() - return batch_data - def _shutdown(self): with self.shutdown_flag.get_lock(): self.shutdown_flag.value = 1 @@ -364,10 +371,24 @@ class _BaseStreamDataLoaderIter: self.collator = loader.collator self.num_workers = loader.num_workers self.timeout = loader.timeout + self.timeout_event = loader.timeout_event def _get_next_batch(self): raise NotImplementedError + def _process_raw_data(self, raw_data): + assert len(raw_data) == 2 and isinstance( + raw_data[0], bool + ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched." + if not raw_data[0]: + data = list((x,) for x in raw_data[1]) + else: + data = raw_data[1] + ret = [] + for idx in range(len(data[0])): + ret.append(tuple(e[idx] for e in data)) + return ret + def __iter__(self): return self @@ -380,42 +401,43 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter): super().__init__(loader) self.dataset_iter = iter(self.dataset) self.idx = 0 - self.data = None + self.unused = [] - def _get_next_batch(self): - ret = [] - while len(ret) != self.sampler.batch_size: - if self.idx != 0: - data = self.data - else: - try: + def _try_get_raw_data(self, start_time): + raw_data = None + while not raw_data: + try: + if self.timeout > 0: timer = threading.Timer(self.timeout, thread.interrupt_main) timer.start() - raw_data = next(self.dataset_iter) + raw_data = next(self.dataset_iter) + if self.timeout > 0: timer.cancel() - except KeyboardInterrupt: - raise RuntimeError("get_next_batch timeout!") - except: + except KeyboardInterrupt: + raw_data = self.timeout_event() + except: + if self.timeout > 0: timer.cancel() - continue - assert len(raw_data) == 2 and isinstance( - raw_data[0], bool - ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched." - if not raw_data[0]: - data = list((x,) for x in raw_data[1]) - else: - data = raw_data[1] - for idx in range(self.idx, len(data[0])): - trans_data = self.transform.apply(tuple(e[idx] for e in data)) - ret.append(trans_data) - if len(ret) == self.sampler.batch_size: - if idx + 1 == len(data[0]): - self.idx = 0 - self.data = None - else: - self.idx = idx - self.data = data - break + waited_time = time.time() - start_time + if waited_time > self.timeout: + raw_data = self.timeout_event() + return raw_data + + def _get_next_batch(self): + ret = [] + start_time = time.time() + while len(ret) < self.sampler.batch_size: + if len(self.unused) != 0: + batch_data = self.unused + else: + raw_data = self._try_get_raw_data(start_time) + batch_data = self._process_raw_data(raw_data) + + while len(batch_data) != 0 and len(ret) < self.sampler.batch_size: + data = batch_data.pop() + ret.append(self.transform.apply(data)) + self.unused = batch_data + return self.collator.apply(ret) @@ -440,49 +462,52 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): self.batch_queue = PlasmaShmQueue(maxsize=2) - self.recieve_worker = multiprocessing.Process(target=self._recieve, daemon=True) + self.recieve_worker = multiprocessing.Process( + target=self._worker_to_raw_data_queues, daemon=True + ) self.recieve_worker.start() self.transform_workers = [] for worker_id in range(self.num_workers): worker = multiprocessing.Process( - target=self._transform, args=(worker_id,), daemon=True + target=self._worker_to_trans_data_queues, args=(worker_id,), daemon=True ) worker.start() self.transform_workers.append(worker) - self.collect_worker = multiprocessing.Process(target=self._collect, daemon=True) + self.collect_worker = multiprocessing.Process( + target=self._worker_to_batch_queue, daemon=True + ) self.collect_worker.start() self.__initialized = True - def _recieve(self): + def _put_raw_data_queues(self, raw_data, qidx): + batch_data = self._process_raw_data(raw_data) + for data in batch_data: + while True: + qidx = qidx % self.num_workers + try: + self.raw_data_queues[qidx].put(data) + break + except queue.Full: + if self.shutdown_flag.value == 1: + break + logger.debug("raw data queue %d is full" % qidx) + finally: + qidx += 1 + return qidx + + def _worker_to_raw_data_queues(self): dataset_iter = iter(self.dataset) - cnt = -1 + qidx = 0 while True: if self.shutdown_flag.value == 1: break raw_data = next(dataset_iter) - assert len(raw_data) == 2 and isinstance( - raw_data[0], bool - ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched." - if not raw_data[0]: - data = list((x,) for x in raw_data[1]) - else: - data = raw_data[1] - for idx in range(len(data[0])): - while True: - cnt += 1 - qid = cnt % self.num_workers - try: - self.raw_data_queues[qid].put(tuple(e[idx] for e in data)) - break - except queue.Full: - if self.shutdown_flag.value == 1: - break - logger.debug("raw data queue is full") + qidx = self._put_raw_data_queues(raw_data, qidx) - def _transform(self, worker_id): + def _worker_to_trans_data_queues(self, worker_id): while True: if self.shutdown_flag.value == 1: break @@ -500,7 +525,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): break logger.debug("batch queue if full") - def _collect(self): + def _worker_to_batch_queue(self): cnt = -1 trans_items = [] while True: @@ -541,7 +566,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): "worker: {} died. {}".format(worker_id, exitcode) ) - def _try_get_next_batch(self): + def _get_next_batch(self): start_time = time.time() while True: self._check_workers() @@ -551,11 +576,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): logger.debug("batch queue empty!") waited_time = time.time() - start_time if self.timeout > 0 and waited_time > self.timeout: - raise RuntimeError("get_next_batch timeout!") - - def _get_next_batch(self): - batch_data = self._try_get_next_batch() - return batch_data + self._put_raw_data_queues(self.timeout_event(), 0) def _shutdown(self): with self.shutdown_flag.get_lock(): diff --git a/imperative/python/megengine/data/dataset/meta_dataset.py b/imperative/python/megengine/data/dataset/meta_dataset.py index dd1f01c9..8673b47c 100644 --- a/imperative/python/megengine/data/dataset/meta_dataset.py +++ b/imperative/python/megengine/data/dataset/meta_dataset.py @@ -43,7 +43,7 @@ class StreamDataset(Dataset): def __iter__(self): pass - def __getitem__(self): + def __getitem__(self, idx): raise AssertionError("can not get item from StreamDataset by index") def __len__(self): diff --git a/imperative/python/test/unit/data/test_dataloader.py b/imperative/python/test/unit/data/test_dataloader.py index 8dd25b68..4897e26e 100644 --- a/imperative/python/test/unit/data/test_dataloader.py +++ b/imperative/python/test/unit/data/test_dataloader.py @@ -61,10 +61,10 @@ def test_dataloader_init(): class MyStream(StreamDataset): - def __init__(self, number, batch=False, error=False, block=False): + def __init__(self, number, batch=False, error_foramt=False, block=False): self.number = number self.batch = batch - self.error = error + self.error_format = error_foramt self.block = block def __iter__(self): @@ -73,11 +73,11 @@ class MyStream(StreamDataset): for _ in range(10): time.sleep(1) if self.batch: - data = np.random.randint(0, 256, (2, 32, 32, 3), dtype="uint8") + data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8") yield (True, (data, [cnt, cnt - self.number])) else: - data = np.random.randint(0, 256, (32, 32, 3), dtype="uint8") - if self.error: + data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8") + if self.error_format: yield (data, cnt) else: yield (False, (data, cnt)) @@ -87,7 +87,7 @@ class MyStream(StreamDataset): @pytest.mark.parametrize("batch", [True, False]) @pytest.mark.parametrize("num_workers", [0, 2]) def test_stream_dataloader(batch, num_workers): - dataset = MyStream(100, batch) + dataset = MyStream(100, batch=batch) sampler = StreamSampler(batch_size=4) dataloader = DataLoader( dataset, @@ -101,7 +101,7 @@ def test_stream_dataloader(batch, num_workers): for step, data in enumerate(dataloader): if step == 10: break - assert data[0].shape == (4, 3, 32, 32) + assert data[0].shape == (4, 3, 2, 2) assert data[1].shape == (4,) for i in data[1]: assert i not in check_set @@ -109,7 +109,7 @@ def test_stream_dataloader(batch, num_workers): def test_stream_dataloader_error(): - dataset = MyStream(100, error=True) + dataset = MyStream(100, error_foramt=True) sampler = StreamSampler(batch_size=4) dataloader = DataLoader(dataset, sampler) with pytest.raises(AssertionError, match=r".*tuple.*"): @@ -122,7 +122,7 @@ def test_stream_dataloader_timeout(num_workers): dataset = MyStream(100, False, block=True) sampler = StreamSampler(batch_size=4) - dataloader = DataLoader(dataset, sampler, num_workers=num_workers, timeout=5) + dataloader = DataLoader(dataset, sampler, num_workers=num_workers, timeout=2) with pytest.raises(RuntimeError, match=r".*timeout.*"): data_iter = iter(dataloader) next(data_iter) @@ -264,3 +264,20 @@ def test_dataloader_parallel_multi_instances_multiprocessing(): for p in processes: p.join() + + +@pytest.mark.parametrize("num_workers", [0, 2]) +def test_timeout_event(num_workers): + def cb(): + return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,)))) + + dataset = MyStream(100, block=True) + sampler = StreamSampler(batch_size=4) + + dataloader = DataLoader( + dataset, sampler, num_workers=num_workers, timeout=2, timeout_event=cb + ) + for _, data in enumerate(dataloader): + np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3))) + np.testing.assert_equal(data[1], np.ones(shape=(4,))) + break