Browse Source

refactor(mge/data): Refactor megeninge.data.dataset

GitOrigin-RevId: 1d9c61ce70
release-1.1
Megvii Engine Team 4 years ago
parent
commit
e082e27780
4 changed files with 202 additions and 48 deletions
  1. +1
    -0
      imperative/python/megengine/data/__init__.py
  2. +113
    -43
      imperative/python/megengine/data/dataloader.py
  3. +5
    -2
      imperative/python/megengine/data/sampler.py
  4. +83
    -3
      imperative/python/test/unit/data/test_dataloader.py

+ 1
- 0
imperative/python/megengine/data/__init__.py View File

@@ -10,6 +10,7 @@ from .collator import Collator
from .dataloader import DataLoader
from .sampler import (
Infinite,
MapSampler,
RandomSampler,
ReplacementSampler,
Sampler,


+ 113
- 43
imperative/python/megengine/data/dataloader.py View File

@@ -20,7 +20,7 @@ from ..logger import get_logger
from ..random.rng import _random_seed_generator
from .collator import Collator
from .dataset import Dataset, MapDataset, StreamDataset
from .sampler import Sampler, SequentialSampler, StreamSampler
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform

logger = get_logger(__name__)
@@ -88,17 +88,24 @@ class DataLoader:

self.divide = divide

if sampler is None:
if isinstance(dataset, MapDataset):
self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False)
elif isinstance(dataset, StreamDataset):
self.sampler = StreamSampler(batch_size=1)
else:
raise TypeError(
"can not recognize this kind of dataset: %s" % type(dataset)
)
if isinstance(dataset, MapDataset):
self.sampler = (
sampler
if sampler
else SequentialSampler(dataset, batch_size=1, drop_last=False)
)
assert isinstance(
self.sampler, MapSampler
), "types of dataset and sampler do not match"
elif isinstance(dataset, StreamDataset):
self.sampler = sampler if sampler else StreamSampler(batch_size=1)
assert isinstance(
self.sampler, StreamSampler
), "types of dataset and sampler do not match"
else:
self.sampler = sampler
raise TypeError(
"can not recognize this kind of dataset: %s" % type(dataset)
)

if divide:
if self.sampler.batch_size <= self.num_workers:
@@ -352,7 +359,6 @@ class _BaseStreamDataLoaderIter:
self.collator = loader.collator
self.num_workers = loader.num_workers
self.timeout = loader.timeout
self.post_process = self.dataset.post_process

def _get_next_batch(self):
raise NotImplementedError
@@ -361,13 +367,15 @@ class _BaseStreamDataLoaderIter:
return self

def __next__(self):
return self.post_process(self._get_next_batch())
return self._get_next_batch()


class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
def __init__(self, loader):
super().__init__(loader)
self.dataset_iter = iter(self.dataset)
self.idx = 0
self.data = None

def _get_next_batch(self):
ret = []
@@ -376,11 +384,30 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
waited_time = time.time() - start_time
if self.timeout > 0 and waited_time > self.timeout:
raise RuntimeError("get_next_batch timeout!")
item = next(self.dataset_iter)
for idx in range(len(item[0])):
trans_item = self.transform.apply(tuple(e[idx] for e in item))
ret.append(trans_item)
if self.idx != 0:
data = self.data
else:
try:
raw_data = next(self.dataset_iter)
except:
continue
assert len(raw_data) == 2 and isinstance(
raw_data[0], bool
), "raw_data must be a tuple"
if not raw_data[0]:
data = list((x,) for x in raw_data[1])
else:
data = raw_data[1]
for idx in range(self.idx, len(data[0])):
trans_data = self.transform.apply(tuple(e[idx] for e in data))
ret.append(trans_data)
if len(ret) == self.sampler.batch_size:
if idx + 1 == len(data[0]):
self.idx = 0
self.data = None
else:
self.idx = idx
self.data = data
break
return self.collator.apply(ret)

@@ -393,45 +420,80 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):

self.shutdown_flag = multiprocessing.Value("i", 0)

self.raw_data_queues = [
multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
]

self.trans_data_queues = [
multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
]

# shared-memory queue implemented by pyarrow plasma store
from ._queue import PlasmaShmQueue

self.batch_queue = PlasmaShmQueue(maxsize=2)
self.workers = []
self.worker_queues = [
multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
]

self.recieve_worker = multiprocessing.Process(target=self._recieve, daemon=True)
self.recieve_worker.start()

self.transform_workers = []
for worker_id in range(self.num_workers):
worker = multiprocessing.Process(
target=self._gen_data, args=(worker_id,), daemon=True
target=self._transform, args=(worker_id,), daemon=True
)
worker.start()
self.workers.append(worker)
self.collator_worker = multiprocessing.Process(
target=self._gen_batch, daemon=True
)
self.collator_worker.start()
self.transform_workers.append(worker)

self.collect_worker = multiprocessing.Process(target=self._collect, daemon=True)
self.collect_worker.start()

self.__initialized = True

def _gen_data(self, worker_id):
def _recieve(self):
dataset_iter = iter(self.dataset)
cnt = -1
while True:
if self.shutdown_flag.value == 1:
break
item = next(dataset_iter)
for idx in range(len(item[0])):
trans_item = self.transform.apply(tuple(e[idx] for e in item))
raw_data = next(dataset_iter)
assert len(raw_data) == 2 and isinstance(
raw_data[0], bool
), "raw_data must be a tuple"
if not raw_data[0]:
data = list((x,) for x in raw_data[1])
else:
data = raw_data[1]
for idx in range(len(data[0])):
while True:
cnt += 1
qid = cnt % self.num_workers
try:
self.worker_queues[worker_id].put(trans_item)
self.raw_data_queues[qid].put(tuple(e[idx] for e in data))
break
except queue.Full:
if self.shutdown_flag.value == 1:
break
logger.debug("batch part queue is full")
logger.debug("raw data queue is full")

def _gen_batch(self):
def _transform(self, worker_id):
while True:
if self.shutdown_flag.value == 1:
break
try:
data = self.raw_data_queues[worker_id].get(timeout=MP_QUEUE_GET_TIMEOUT)
except queue.Empty:
continue
trans_data = self.transform.apply(data)
while True:
try:
self.trans_data_queues[worker_id].put(trans_data)
break
except queue.Full:
if self.shutdown_flag.value == 1:
break
logger.debug("batch queue if full")

def _collect(self):
cnt = -1
trans_items = []
while True:
@@ -440,7 +502,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
cnt += 1
queue_id = cnt % self.num_workers
try:
trans_item = self.worker_queues[queue_id].get(
trans_item = self.trans_data_queues[queue_id].get(
timeout=MP_QUEUE_GET_TIMEOUT
)
except queue.Empty:
@@ -459,12 +521,12 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
trans_items = []

def _check_workers(self):
if not self.collator_worker.is_alive():
exitcode = self.collator_worker.exitcode
if not self.collect_worker.is_alive():
exitcode = self.collect_worker.exitcode
if exitcode != 0:
raise RuntimeError("collator worker died. {}".format(exitcode))

for worker_id, worker in enumerate(self.workers):
for worker_id, worker in enumerate(self.transform_workers):
if not worker.is_alive():
exitcode = worker.exitcode
if exitcode != 0:
@@ -492,16 +554,24 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
with self.shutdown_flag.get_lock():
self.shutdown_flag.value = 1

if self.collator_worker.is_alive():
self.collator_worker.terminate()
self.collator_worker.join()
if self.recieve_worker.is_alive():
self.recieve_worker.terminate()
self.recieve_worker.join()

for worker in self.workers:
if self.collect_worker.is_alive():
self.collect_worker.terminate()
self.collect_worker.join()

for worker in self.transform_workers:
if worker.is_alive():
worker.terminate()
worker.join()

for q in self.worker_queues:
for q in self.raw_data_queues:
q.cancel_join_thread()
q.close()

for q in self.trans_data_queues:
q.cancel_join_thread()
q.close()



+ 5
- 2
imperative/python/megengine/data/sampler.py View File

@@ -161,10 +161,13 @@ class StreamSampler(Sampler):

.. warning::

In the case of multiple workers, sampler should ensure that each worker gets
In the case of multiple machines, sampler should ensure that each worker gets
different data. But this class cannot do it yet, please build your own
dataset and sampler to achieve this goal.

Usually, meth::`~.StreamDataset.__iter__` can return different iterator by
``rank = dist.get_rank()``. So that they will get different data.

"""

def __init__(self, batch_size=1):
@@ -174,7 +177,7 @@ class StreamSampler(Sampler):
return self

def __next__(self):
return range(self.batch_size)
return iter(range(self.batch_size))


class SequentialSampler(MapSampler):


+ 83
- 3
imperative/python/test/unit/data/test_dataloader.py View File

@@ -15,9 +15,15 @@ import pytest

from megengine.data.collator import Collator
from megengine.data.dataloader import DataLoader
from megengine.data.dataset import ArrayDataset
from megengine.data.sampler import RandomSampler, SequentialSampler
from megengine.data.transform import PseudoTransform, Transform
from megengine.data.dataset import ArrayDataset, StreamDataset
from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler
from megengine.data.transform import (
Compose,
Normalize,
PseudoTransform,
ToMode,
Transform,
)


def init_dataset():
@@ -54,6 +60,80 @@ def test_dataloader_init():
assert len(dataloader) == 16


class MyStream(StreamDataset):
def __init__(self, number, batch=False, error=False):
self.number = number
self.batch = batch
self.error = error

def __iter__(self):
for cnt in range(self.number):
if self.batch:
data = np.random.randint(0, 256, (2, 32, 32, 3), dtype="uint8")
yield (True, (data, [cnt, cnt - self.number]))
else:
data = np.random.randint(0, 256, (32, 32, 3), dtype="uint8")
if self.error:
yield (data, cnt)
else:
yield (False, (data, cnt))
raise StopIteration


@pytest.mark.parametrize("batch", [True, False])
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(batch, num_workers):
dataset = MyStream(100, batch)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(
dataset,
sampler,
Compose([Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]),
num_workers=num_workers,
)

check_set = set()

for step, data in enumerate(dataloader):
if step == 10:
break
assert data[0].shape == (4, 3, 32, 32)
assert data[1].shape == (4,)
for i in data[1]:
assert i not in check_set
check_set.add(i)


def test_stream_dataloader_error():
dataset = MyStream(100, error=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(dataset, sampler)
with pytest.raises(AssertionError, match=r".*tuple.*"):
data_iter = iter(dataloader)
next(data_iter)


@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader_timeout(num_workers):
dataset = MyStream(100, False)
sampler = StreamSampler(batch_size=4)

class TimeoutTransform(Transform):
def __init__(self):
pass

def apply(self, input):
time.sleep(10)
return input

dataloader = DataLoader(
dataset, sampler, TimeoutTransform(), num_workers=num_workers, timeout=5
)
with pytest.raises(RuntimeError, match=r".*timeout.*"):
data_iter = iter(dataloader)
next(data_iter)


def test_dataloader_serial():
dataset = init_dataset()
dataloader = DataLoader(


Loading…
Cancel
Save