
refactor(mge/data): Refactor megengine.data.dataset

GitOrigin-RevId: 1d9c61ce70
release-1.1
Megvii Engine Team 4 years ago
parent commit e082e27780
4 changed files with 202 additions and 48 deletions
  1. imperative/python/megengine/data/__init__.py (+1, -0)
  2. imperative/python/megengine/data/dataloader.py (+113, -43)
  3. imperative/python/megengine/data/sampler.py (+5, -2)
  4. imperative/python/test/unit/data/test_dataloader.py (+83, -3)

imperative/python/megengine/data/__init__.py (+1, -0)

@@ -10,6 +10,7 @@ from .collator import Collator
 from .dataloader import DataLoader
 from .sampler import (
     Infinite,
+    MapSampler,
     RandomSampler,
     ReplacementSampler,
     Sampler,


imperative/python/megengine/data/dataloader.py (+113, -43)

@@ -20,7 +20,7 @@ from ..logger import get_logger
 from ..random.rng import _random_seed_generator
 from .collator import Collator
 from .dataset import Dataset, MapDataset, StreamDataset
-from .sampler import Sampler, SequentialSampler, StreamSampler
+from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
 from .transform import PseudoTransform, Transform

 logger = get_logger(__name__)
@@ -88,17 +88,24 @@ class DataLoader:
         self.divide = divide

-        if sampler is None:
-            if isinstance(dataset, MapDataset):
-                self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False)
-            elif isinstance(dataset, StreamDataset):
-                self.sampler = StreamSampler(batch_size=1)
-            else:
-                raise TypeError(
-                    "can not recognize this kind of dataset: %s" % type(dataset)
-                )
+        if isinstance(dataset, MapDataset):
+            self.sampler = (
+                sampler
+                if sampler
+                else SequentialSampler(dataset, batch_size=1, drop_last=False)
+            )
+            assert isinstance(
+                self.sampler, MapSampler
+            ), "types of dataset and sampler do not match"
+        elif isinstance(dataset, StreamDataset):
+            self.sampler = sampler if sampler else StreamSampler(batch_size=1)
+            assert isinstance(
+                self.sampler, StreamSampler
+            ), "types of dataset and sampler do not match"
         else:
-            self.sampler = sampler
+            raise TypeError(
+                "can not recognize this kind of dataset: %s" % type(dataset)
+            )

         if divide:
             if self.sampler.batch_size <= self.num_workers:
@@ -352,7 +359,6 @@ class _BaseStreamDataLoaderIter:
         self.collator = loader.collator
         self.num_workers = loader.num_workers
         self.timeout = loader.timeout
-        self.post_process = self.dataset.post_process

     def _get_next_batch(self):
         raise NotImplementedError
@@ -361,13 +367,15 @@ class _BaseStreamDataLoaderIter:
         return self

     def __next__(self):
-        return self.post_process(self._get_next_batch())
+        return self._get_next_batch()


 class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
     def __init__(self, loader):
         super().__init__(loader)
         self.dataset_iter = iter(self.dataset)
+        self.idx = 0
+        self.data = None

     def _get_next_batch(self):
         ret = []
@@ -376,11 +384,30 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
             waited_time = time.time() - start_time
             if self.timeout > 0 and waited_time > self.timeout:
                 raise RuntimeError("get_next_batch timeout!")
-            item = next(self.dataset_iter)
-            for idx in range(len(item[0])):
-                trans_item = self.transform.apply(tuple(e[idx] for e in item))
-                ret.append(trans_item)
+            if self.idx != 0:
+                data = self.data
+            else:
+                try:
+                    raw_data = next(self.dataset_iter)
+                except:
+                    continue
+                assert len(raw_data) == 2 and isinstance(
+                    raw_data[0], bool
+                ), "raw_data must be a tuple"
+                if not raw_data[0]:
+                    data = list((x,) for x in raw_data[1])
+                else:
+                    data = raw_data[1]
+            for idx in range(self.idx, len(data[0])):
+                trans_data = self.transform.apply(tuple(e[idx] for e in data))
+                ret.append(trans_data)
                 if len(ret) == self.sampler.batch_size:
+                    if idx + 1 == len(data[0]):
+                        self.idx = 0
+                        self.data = None
+                    else:
+                        self.idx = idx
+                        self.data = data
                     break
         return self.collator.apply(ret)


@@ -393,45 +420,80 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):

         self.shutdown_flag = multiprocessing.Value("i", 0)

+        self.raw_data_queues = [
+            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
+        ]
+
+        self.trans_data_queues = [
+            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
+        ]
+
         # shared-memory queue implemented by pyarrow plasma store
         from ._queue import PlasmaShmQueue

         self.batch_queue = PlasmaShmQueue(maxsize=2)
-        self.workers = []
-        self.worker_queues = [
-            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
-        ]
+
+        self.recieve_worker = multiprocessing.Process(target=self._recieve, daemon=True)
+        self.recieve_worker.start()
+
+        self.transform_workers = []
         for worker_id in range(self.num_workers):
             worker = multiprocessing.Process(
-                target=self._gen_data, args=(worker_id,), daemon=True
+                target=self._transform, args=(worker_id,), daemon=True
             )
             worker.start()
-            self.workers.append(worker)
-        self.collator_worker = multiprocessing.Process(
-            target=self._gen_batch, daemon=True
-        )
-        self.collator_worker.start()
+            self.transform_workers.append(worker)
+
+        self.collect_worker = multiprocessing.Process(target=self._collect, daemon=True)
+        self.collect_worker.start()

         self.__initialized = True

-    def _gen_data(self, worker_id):
+    def _recieve(self):
         dataset_iter = iter(self.dataset)
+        cnt = -1
         while True:
             if self.shutdown_flag.value == 1:
                 break
-            item = next(dataset_iter)
-            for idx in range(len(item[0])):
-                trans_item = self.transform.apply(tuple(e[idx] for e in item))
+            raw_data = next(dataset_iter)
+            assert len(raw_data) == 2 and isinstance(
+                raw_data[0], bool
+            ), "raw_data must be a tuple"
+            if not raw_data[0]:
+                data = list((x,) for x in raw_data[1])
+            else:
+                data = raw_data[1]
+            for idx in range(len(data[0])):
                 while True:
+                    cnt += 1
+                    qid = cnt % self.num_workers
                     try:
-                        self.worker_queues[worker_id].put(trans_item)
+                        self.raw_data_queues[qid].put(tuple(e[idx] for e in data))
                         break
                     except queue.Full:
                         if self.shutdown_flag.value == 1:
                             break
-                        logger.debug("batch part queue is full")
+                        logger.debug("raw data queue is full")

-    def _gen_batch(self):
+    def _transform(self, worker_id):
+        while True:
+            if self.shutdown_flag.value == 1:
+                break
+            try:
+                data = self.raw_data_queues[worker_id].get(timeout=MP_QUEUE_GET_TIMEOUT)
+            except queue.Empty:
+                continue
+            trans_data = self.transform.apply(data)
+            while True:
+                try:
+                    self.trans_data_queues[worker_id].put(trans_data)
+                    break
+                except queue.Full:
+                    if self.shutdown_flag.value == 1:
+                        break
+                    logger.debug("batch queue if full")
+
+    def _collect(self):
         cnt = -1
         trans_items = []
         while True:
@@ -440,7 +502,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
             cnt += 1
             queue_id = cnt % self.num_workers
             try:
-                trans_item = self.worker_queues[queue_id].get(
+                trans_item = self.trans_data_queues[queue_id].get(
                     timeout=MP_QUEUE_GET_TIMEOUT
                 )
             except queue.Empty:
@@ -459,12 +521,12 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
                 trans_items = []

     def _check_workers(self):
-        if not self.collator_worker.is_alive():
-            exitcode = self.collator_worker.exitcode
+        if not self.collect_worker.is_alive():
+            exitcode = self.collect_worker.exitcode
             if exitcode != 0:
                 raise RuntimeError("collator worker died. {}".format(exitcode))

-        for worker_id, worker in enumerate(self.workers):
+        for worker_id, worker in enumerate(self.transform_workers):
             if not worker.is_alive():
                 exitcode = worker.exitcode
                 if exitcode != 0:
@@ -492,16 +554,24 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
         with self.shutdown_flag.get_lock():
             self.shutdown_flag.value = 1

-        if self.collator_worker.is_alive():
-            self.collator_worker.terminate()
-        self.collator_worker.join()
+        if self.recieve_worker.is_alive():
+            self.recieve_worker.terminate()
+        self.recieve_worker.join()

-        for worker in self.workers:
+        if self.collect_worker.is_alive():
+            self.collect_worker.terminate()
+        self.collect_worker.join()
+
+        for worker in self.transform_workers:
             if worker.is_alive():
                 worker.terminate()
             worker.join()

-        for q in self.worker_queues:
+        for q in self.raw_data_queues:
+            q.cancel_join_thread()
+            q.close()
+
+        for q in self.trans_data_queues:
             q.cancel_join_thread()
             q.close()
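
Note on the stream protocol introduced above: the refactored loader asserts that every item a StreamDataset yields is a (batched, data) pair, where False means data holds the fields of a single sample and True means each field is already a batched array. Below is a minimal sketch of that convention, not part of this commit; RandomImageStream is a made-up name used only for illustration.

    # Minimal sketch (not part of this commit); RandomImageStream is illustrative only.
    import numpy as np

    from megengine.data import DataLoader
    from megengine.data.dataset import StreamDataset
    from megengine.data.sampler import StreamSampler


    class RandomImageStream(StreamDataset):
        def __iter__(self):
            cnt = 0
            while True:
                img = np.random.randint(0, 256, (32, 32, 3), dtype="uint8")
                # First element False: the second element holds the fields of one sample.
                # True would mean each field is already batched (see MyStream in the tests below).
                yield (False, (img, cnt))
                cnt += 1


    loader = DataLoader(RandomImageStream(), sampler=StreamSampler(batch_size=4))
    imgs, idxs = next(iter(loader))  # expected: imgs.shape == (4, 32, 32, 3), idxs.shape == (4,)

With num_workers > 0, the same items would instead flow through the receive/transform/collect processes added in this diff.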




imperative/python/megengine/data/sampler.py (+5, -2)

@@ -161,10 +161,13 @@ class StreamSampler(Sampler):

     .. warning::

-        In the case of multiple workers, sampler should ensure that each worker gets
+        In the case of multiple machines, sampler should ensure that each worker gets
         different data. But this class cannot do it yet, please build your own
         dataset and sampler to achieve this goal.

+        Usually, meth::`~.StreamDataset.__iter__` can return different iterator by
+        ``rank = dist.get_rank()``. So that they will get different data.
+
     """

     def __init__(self, batch_size=1):
@@ -174,7 +177,7 @@ class StreamSampler(Sampler):
         return self

     def __next__(self):
-        return range(self.batch_size)
+        return iter(range(self.batch_size))


 class SequentialSampler(MapSampler):
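
The docstring addition above points at per-rank splitting via ``dist.get_rank()``. A minimal sketch of that idea follows, not part of this commit; ShardedStream is a made-up name, and ``get_world_size()`` is assumed to be available alongside ``get_rank()`` in megengine.distributed.

    # Minimal sketch (not part of this commit) of the per-rank split the docstring suggests.
    import megengine.distributed as dist

    from megengine.data.dataset import StreamDataset


    class ShardedStream(StreamDataset):  # illustrative name
        def __init__(self, samples):
            self.samples = samples  # any sequence of (data, label) pairs

        def __iter__(self):
            rank = dist.get_rank()              # this worker's rank, as the docstring suggests
            world_size = dist.get_world_size()  # assumed companion helper
            # Strided slice per rank: no two ranks ever yield the same sample.
            for idx in range(rank, len(self.samples), world_size):
                data, label = self.samples[idx]
                yield (False, (data, label))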


imperative/python/test/unit/data/test_dataloader.py (+83, -3)

@@ -15,9 +15,15 @@ import pytest

 from megengine.data.collator import Collator
 from megengine.data.dataloader import DataLoader
-from megengine.data.dataset import ArrayDataset
-from megengine.data.sampler import RandomSampler, SequentialSampler
-from megengine.data.transform import PseudoTransform, Transform
+from megengine.data.dataset import ArrayDataset, StreamDataset
+from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler
+from megengine.data.transform import (
+    Compose,
+    Normalize,
+    PseudoTransform,
+    ToMode,
+    Transform,
+)


 def init_dataset():
@@ -54,6 +60,80 @@ def test_dataloader_init():
     assert len(dataloader) == 16


+class MyStream(StreamDataset):
+    def __init__(self, number, batch=False, error=False):
+        self.number = number
+        self.batch = batch
+        self.error = error
+
+    def __iter__(self):
+        for cnt in range(self.number):
+            if self.batch:
+                data = np.random.randint(0, 256, (2, 32, 32, 3), dtype="uint8")
+                yield (True, (data, [cnt, cnt - self.number]))
+            else:
+                data = np.random.randint(0, 256, (32, 32, 3), dtype="uint8")
+                if self.error:
+                    yield (data, cnt)
+                else:
+                    yield (False, (data, cnt))
+        raise StopIteration
+
+
+@pytest.mark.parametrize("batch", [True, False])
+@pytest.mark.parametrize("num_workers", [0, 2])
+def test_stream_dataloader(batch, num_workers):
+    dataset = MyStream(100, batch)
+    sampler = StreamSampler(batch_size=4)
+    dataloader = DataLoader(
+        dataset,
+        sampler,
+        Compose([Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]),
+        num_workers=num_workers,
+    )
+
+    check_set = set()
+
+    for step, data in enumerate(dataloader):
+        if step == 10:
+            break
+        assert data[0].shape == (4, 3, 32, 32)
+        assert data[1].shape == (4,)
+        for i in data[1]:
+            assert i not in check_set
+            check_set.add(i)
+
+
+def test_stream_dataloader_error():
+    dataset = MyStream(100, error=True)
+    sampler = StreamSampler(batch_size=4)
+    dataloader = DataLoader(dataset, sampler)
+    with pytest.raises(AssertionError, match=r".*tuple.*"):
+        data_iter = iter(dataloader)
+        next(data_iter)
+
+
+@pytest.mark.parametrize("num_workers", [0, 2])
+def test_stream_dataloader_timeout(num_workers):
+    dataset = MyStream(100, False)
+    sampler = StreamSampler(batch_size=4)
+
+    class TimeoutTransform(Transform):
+        def __init__(self):
+            pass
+
+        def apply(self, input):
+            time.sleep(10)
+            return input
+
+    dataloader = DataLoader(
+        dataset, sampler, TimeoutTransform(), num_workers=num_workers, timeout=5
+    )
+    with pytest.raises(RuntimeError, match=r".*timeout.*"):
+        data_iter = iter(dataloader)
+        next(data_iter)
+
+
 def test_dataloader_serial():
     dataset = init_dataset()
     dataloader = DataLoader(

