|
|
@@ -1,4 +1,12 @@ |
|
|
|
# -*- coding: utf-8 -*- |
|
|
|
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") |
|
|
|
# |
|
|
|
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. |
|
|
|
# |
|
|
|
# Unless required by applicable law or agreed to in writing, |
|
|
|
# software distributed under the License is distributed on an |
|
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
import math |
|
|
|
import os |
|
|
|
import platform |
|
|
|
import time |
|
|
@@ -7,7 +15,7 @@ import numpy as np |
|
|
|
import pytest |
|
|
|
|
|
|
|
from megengine.data.collator import Collator |
|
|
|
from megengine.data.dataloader import DataLoader |
|
|
|
from megengine.data.dataloader import DataLoader, get_worker_info |
|
|
|
from megengine.data.dataset import ArrayDataset, StreamDataset |
|
|
|
from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler |
|
|
|
from megengine.data.transform import ( |
|
|
@@ -30,13 +38,9 @@ def init_dataset(): |
|
|
|
def test_dataloader_init(): |
|
|
|
dataset = init_dataset() |
|
|
|
with pytest.raises(ValueError): |
|
|
|
dataloader = DataLoader(dataset, num_workers=2, divide=True) |
|
|
|
with pytest.raises(ValueError): |
|
|
|
dataloader = DataLoader(dataset, num_workers=-1) |
|
|
|
with pytest.raises(ValueError): |
|
|
|
dataloader = DataLoader(dataset, timeout=-1) |
|
|
|
with pytest.raises(ValueError): |
|
|
|
dataloader = DataLoader(dataset, num_workers=0, divide=True) |
|
|
|
|
|
|
|
dataloader = DataLoader(dataset) |
|
|
|
assert isinstance(dataloader.sampler, SequentialSampler) |
|
|
@@ -54,10 +58,8 @@ def test_dataloader_init(): |
|
|
|
|
|
|
|
|
|
|
|
class MyStream(StreamDataset): |
|
|
|
def __init__(self, number, batch=False, error_foramt=False, block=False): |
|
|
|
def __init__(self, number, block=False): |
|
|
|
self.number = number |
|
|
|
self.batch = batch |
|
|
|
self.error_format = error_foramt |
|
|
|
self.block = block |
|
|
|
|
|
|
|
def __iter__(self): |
|
|
@@ -65,22 +67,14 @@ class MyStream(StreamDataset): |
|
|
|
if self.block: |
|
|
|
for _ in range(10): |
|
|
|
time.sleep(1) |
|
|
|
if self.batch: |
|
|
|
data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8") |
|
|
|
yield (True, (data, [cnt, cnt - self.number])) |
|
|
|
else: |
|
|
|
data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8") |
|
|
|
if self.error_format: |
|
|
|
yield (data, cnt) |
|
|
|
else: |
|
|
|
yield (False, (data, cnt)) |
|
|
|
data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8") |
|
|
|
yield (data, cnt) |
|
|
|
raise StopIteration |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("batch", [True, False]) |
|
|
|
@pytest.mark.parametrize("num_workers", [0, 2]) |
|
|
|
def test_stream_dataloader(batch, num_workers): |
|
|
|
dataset = MyStream(100, batch=batch) |
|
|
|
def test_stream_dataloader(num_workers): |
|
|
|
dataset = MyStream(100) |
|
|
|
sampler = StreamSampler(batch_size=4) |
|
|
|
dataloader = DataLoader( |
|
|
|
dataset, |
|
|
@@ -90,7 +84,6 @@ def test_stream_dataloader(batch, num_workers): |
|
|
|
) |
|
|
|
|
|
|
|
check_set = set() |
|
|
|
|
|
|
|
for step, data in enumerate(dataloader): |
|
|
|
if step == 10: |
|
|
|
break |
|
|
@@ -101,18 +94,9 @@ def test_stream_dataloader(batch, num_workers): |
|
|
|
check_set.add(i) |
|
|
|
|
|
|
|
|
|
|
|
def test_stream_dataloader_error():
    """Fetching from a stream in the "error format" must raise an AssertionError.

    With the error-format flag set, ``MyStream`` yields bare ``(data, cnt)``
    items; the test expects the first fetch to fail with an ``AssertionError``
    whose message mentions "tuple".
    """
    # NOTE: ``error_foramt`` (sic) is the keyword name MyStream.__init__ declares.
    stream = MyStream(100, error_foramt=True)
    batch_sampler = StreamSampler(batch_size=4)
    loader = DataLoader(stream, batch_sampler)
    with pytest.raises(AssertionError, match=r".*tuple.*"):
        next(iter(loader))
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("num_workers", [0, 2]) |
|
|
|
def test_stream_dataloader_timeout(num_workers): |
|
|
|
dataset = MyStream(100, False, block=True) |
|
|
|
dataset = MyStream(100, block=True) |
|
|
|
sampler = StreamSampler(batch_size=4) |
|
|
|
|
|
|
|
dataloader = DataLoader(dataset, sampler, num_workers=num_workers, timeout=2) |
|
|
@@ -140,17 +124,6 @@ def test_dataloader_parallel(): |
|
|
|
dataset, |
|
|
|
sampler=RandomSampler(dataset, batch_size=4, drop_last=False), |
|
|
|
num_workers=2, |
|
|
|
divide=False, |
|
|
|
) |
|
|
|
for (data, label) in dataloader: |
|
|
|
assert data.shape == (4, 1, 32, 32) |
|
|
|
assert label.shape == (4,) |
|
|
|
|
|
|
|
dataloader = DataLoader( |
|
|
|
dataset, |
|
|
|
sampler=RandomSampler(dataset, batch_size=4, drop_last=False), |
|
|
|
num_workers=2, |
|
|
|
divide=True, |
|
|
|
) |
|
|
|
for (data, label) in dataloader: |
|
|
|
assert data.shape == (4, 1, 32, 32) |
|
|
@@ -205,7 +178,7 @@ def test_dataloader_parallel_worker_exception(): |
|
|
|
transform=FakeErrorTransform(), |
|
|
|
num_workers=2, |
|
|
|
) |
|
|
|
with pytest.raises(RuntimeError, match=r"worker.*died"): |
|
|
|
with pytest.raises(RuntimeError, match=r"exited unexpectedly"): |
|
|
|
data_iter = iter(dataloader) |
|
|
|
batch_data = next(data_iter) |
|
|
|
|
|
|
@@ -213,26 +186,23 @@ def test_dataloader_parallel_worker_exception(): |
|
|
|
def _multi_instances_parallel_dataloader_worker(): |
|
|
|
dataset = init_dataset() |
|
|
|
|
|
|
|
for divide_flag in [True, False]: |
|
|
|
train_dataloader = DataLoader( |
|
|
|
dataset, |
|
|
|
sampler=RandomSampler(dataset, batch_size=4, drop_last=False), |
|
|
|
num_workers=2, |
|
|
|
divide=divide_flag, |
|
|
|
) |
|
|
|
val_dataloader = DataLoader( |
|
|
|
dataset, |
|
|
|
sampler=RandomSampler(dataset, batch_size=10, drop_last=False), |
|
|
|
num_workers=2, |
|
|
|
divide=divide_flag, |
|
|
|
) |
|
|
|
for idx, (data, label) in enumerate(train_dataloader): |
|
|
|
assert data.shape == (4, 1, 32, 32) |
|
|
|
assert label.shape == (4,) |
|
|
|
if idx % 5 == 0: |
|
|
|
for val_data, val_label in val_dataloader: |
|
|
|
assert val_data.shape == (10, 1, 32, 32) |
|
|
|
assert val_label.shape == (10,) |
|
|
|
train_dataloader = DataLoader( |
|
|
|
dataset, |
|
|
|
sampler=RandomSampler(dataset, batch_size=4, drop_last=False), |
|
|
|
num_workers=2, |
|
|
|
) |
|
|
|
val_dataloader = DataLoader( |
|
|
|
dataset, |
|
|
|
sampler=RandomSampler(dataset, batch_size=10, drop_last=False), |
|
|
|
num_workers=2, |
|
|
|
) |
|
|
|
for idx, (data, label) in enumerate(train_dataloader): |
|
|
|
assert data.shape == (4, 1, 32, 32) |
|
|
|
assert label.shape == (4,) |
|
|
|
if idx % 5 == 0: |
|
|
|
for val_data, val_label in val_dataloader: |
|
|
|
assert val_data.shape == (10, 1, 32, 32) |
|
|
|
assert val_label.shape == (10,) |
|
|
|
|
|
|
|
|
|
|
|
def test_dataloader_parallel_multi_instances(): |
|
|
@@ -261,18 +231,81 @@ def test_dataloader_parallel_multi_instances_multiprocessing(): |
|
|
|
assert p.exitcode == 0 |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("num_workers", [0, 2]) |
|
|
|
def test_timeout_event(num_workers): |
|
|
|
def cb(): |
|
|
|
return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,)))) |
|
|
|
def partition(ls, size):
    """Split ``ls`` into consecutive slices holding at most ``size`` items each.

    The slices are returned in order as a list; the last one may be shorter
    than ``size``.
    """
    chunks = []
    for start in range(0, len(ls), size):
        chunks.append(ls[start : start + size])
    return chunks
|
|
|
|
|
|
|
dataset = MyStream(100, block=True) |
|
|
|
|
|
|
|
class MyPreStream(StreamDataset):
    """Stream dataset whose samples are generated up front and sharded per worker.

    ``__init__`` creates ``number`` random ``uint8`` arrays of shape
    ``(2, 2, 3)`` and labels them ``0 .. number - 1``. ``__iter__`` yields only
    the slice of samples/labels that belongs to the calling worker.
    """

    def __init__(self, number, block=False):
        """Pre-generate the samples.

        Args:
            number: how many ``(sample, label)`` pairs to create.
            block: when true, ten one-second sleeps precede every yielded item.
        """
        self.number = list(range(number))
        self.block = block
        # One sample per label. This used to be a hard-coded ``range(100)``,
        # which only lined up with ``self.number`` when ``number == 100``.
        self.data = [
            np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
            for _ in range(number)
        ]

    def __iter__(self):
        """Yield ``(sample, label)`` pairs from this worker's shard."""
        worker_info = get_worker_info()
        # ``worker_info.worker`` is used as the shard count and
        # ``worker_info.idx`` as the index of the shard to read.
        per_worker = int(math.ceil(len(self.data) / float(worker_info.worker)))
        shard = worker_info.idx
        samples = partition(self.data, per_worker)[shard]
        labels = partition(self.number, per_worker)[shard]
        for sample, label in zip(samples, labels):
            if self.block:
                for _ in range(10):
                    time.sleep(1)
            yield (sample, label)
        # Simply fall off the end: an explicit ``raise StopIteration`` inside a
        # generator is converted to RuntimeError under PEP 479 (Python 3.7+).
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(
    platform.system() == "Windows",
    reason="dataloader do not support parallel on windows",
)
def test_prestream_dataloader_multiprocessing():
    """Read ten batches of a pre-generated stream through two parallel workers.

    Each batch must have image shape ``(4, 3, 2, 2)`` and label shape ``(4,)``,
    and no label may show up more than once across the batches checked.
    """
    stream = MyPreStream(100)
    batch_sampler = StreamSampler(batch_size=4)
    transform = Compose(
        [Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]
    )
    loader = DataLoader(
        stream,
        batch_sampler,
        transform,
        num_workers=2,
        parallel_stream=True,
    )

    seen_labels = set()

    for batch_idx, batch in enumerate(loader):
        # Only the first ten batches are inspected.
        if batch_idx == 10:
            break
        assert batch[0].shape == (4, 3, 2, 2)
        labels = batch[1]
        assert labels.shape == (4,)
        for label in labels:
            assert label not in seen_labels
            seen_labels.add(label)
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif( |
|
|
|
platform.system() == "Windows", |
|
|
|
reason="dataloader do not support parallel on windows", |
|
|
|
) |
|
|
|
def test_predataloader_parallel_worker_exception(): |
|
|
|
dataset = MyPreStream(100) |
|
|
|
|
|
|
|
class FakeErrorTransform(Transform): |
|
|
|
def __init__(self): |
|
|
|
pass |
|
|
|
|
|
|
|
def apply(self, input): |
|
|
|
raise RuntimeError("test raise error") |
|
|
|
return input |
|
|
|
|
|
|
|
dataloader = DataLoader( |
|
|
|
dataset, sampler, num_workers=num_workers, timeout=2, timeout_event=cb |
|
|
|
dataset, |
|
|
|
sampler=StreamSampler(batch_size=4), |
|
|
|
transform=FakeErrorTransform(), |
|
|
|
num_workers=2, |
|
|
|
parallel_stream=True, |
|
|
|
) |
|
|
|
for _, data in enumerate(dataloader): |
|
|
|
np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3))) |
|
|
|
np.testing.assert_equal(data[1], np.ones(shape=(4,))) |
|
|
|
break |
|
|
|
with pytest.raises(RuntimeError, match=r"exited unexpectedly"): |
|
|
|
data_iter = iter(dataloader) |
|
|
|
batch_data = next(data_iter) |
|
|
|
print(batch_data.shape) |