|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183 |
- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import os
- import time
-
- import numpy as np
- import pytest
-
- from megengine.data.collator import Collator
- from megengine.data.dataloader import DataLoader
- from megengine.data.dataset import ArrayDataset
- from megengine.data.sampler import RandomSampler, SequentialSampler
- from megengine.data.transform import PseudoTransform, Transform
-
-
def init_dataset():
    """Build a small in-memory dataset of random images and labels.

    Returns:
        An ``ArrayDataset`` wrapping 100 random ``uint8`` arrays of shape
        ``(1, 32, 32)`` and 100 random integer labels drawn from ``[0, 10)``.
    """
    num_samples = 100
    # Images are drawn first, then labels, so the order in which the global
    # NumPy random state is consumed is fixed.
    images = np.random.randint(
        0, 255, size=(num_samples, 1, 32, 32), dtype=np.uint8
    )
    labels = np.random.randint(0, 10, size=(num_samples,), dtype=int)
    return ArrayDataset(images, labels)
-
-
def test_dataloader_init():
    """Check DataLoader argument validation, default components and length."""
    dataset = init_dataset()

    # Each keyword combination below is expected to be rejected with
    # ValueError at construction time.
    rejected_kwargs = (
        {"num_workers": 2, "divide": True},
        {"num_workers": -1},
        {"timeout": -1},
        {"num_workers": 0, "divide": True},
    )
    for kwargs in rejected_kwargs:
        with pytest.raises(ValueError):
            DataLoader(dataset, **kwargs)

    # With only a dataset given, the loader should fall back to these
    # component types.
    default_loader = DataLoader(dataset)
    assert isinstance(default_loader.sampler, SequentialSampler)
    assert isinstance(default_loader.transform, PseudoTransform)
    assert isinstance(default_loader.collator, Collator)

    # init_dataset() yields 100 samples; in batches of 6 that is 17 batches
    # when the partial batch is kept and 16 when it is dropped.
    for drop_last, expected_len in ((False, 17), (True, 16)):
        sampler = RandomSampler(dataset, batch_size=6, drop_last=drop_last)
        loader = DataLoader(dataset, sampler=sampler)
        assert len(loader) == expected_len
-
-
def test_dataloader_serial():
    """Iterate a DataLoader built without ``num_workers`` and check batch shapes."""
    dataset = init_dataset()
    sampler = RandomSampler(dataset, batch_size=4, drop_last=False)
    loader = DataLoader(dataset, sampler=sampler)

    # 100 samples split evenly into batches of 4, so every batch is full.
    for images, labels in loader:
        assert images.shape == (4, 1, 32, 32)
        assert labels.shape == (4,)
-
-
def test_dataloader_parallel():
    """Iterate two-worker DataLoaders with ``divide`` off and then on."""
    # set max shared memory to 100M
    os.environ["MGE_PLASMA_MEMORY"] = "100000000"

    dataset = init_dataset()

    # Same configuration twice; only the ``divide`` flag differs.
    for divide in (False, True):
        loader = DataLoader(
            dataset,
            sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
            num_workers=2,
            divide=divide,
        )
        for images, labels in loader:
            assert images.shape == (4, 1, 32, 32)
            assert labels.shape == (4,)
-
-
def test_dataloader_parallel_timeout():
    """Expect a timeout RuntimeError when the transform outlasts ``timeout``."""
    dataset = init_dataset()

    class TimeoutTransform(Transform):
        """Transform that sleeps for 10 seconds before returning its input."""

        def __init__(self):
            pass

        def apply(self, input):
            # Sleep well past the 2-second loader timeout configured below.
            time.sleep(10)
            return input

    sampler = RandomSampler(dataset, batch_size=4, drop_last=False)
    loader = DataLoader(
        dataset,
        sampler=sampler,
        transform=TimeoutTransform(),
        num_workers=2,
        timeout=2,
    )

    # Fetching the first batch must fail with a message mentioning "timeout".
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        batches = iter(loader)
        next(batches)
-
-
def test_dataloader_parallel_worker_exception():
    """Expect a "worker ... died" RuntimeError when the transform raises."""
    dataset = init_dataset()

    class FakeErrorTransform(Transform):
        """Transform whose ``apply`` always fails."""

        def __init__(self):
            pass

        def apply(self, input):
            # ``x`` is deliberately undefined: this line raises NameError
            # whenever the transform is applied.
            y = x + 1
            return input

    sampler = RandomSampler(dataset, batch_size=4, drop_last=False)
    loader = DataLoader(
        dataset,
        sampler=sampler,
        transform=FakeErrorTransform(),
        num_workers=2,
    )

    # Fetching the first batch must surface the failure in the main process.
    with pytest.raises(RuntimeError, match=r"worker.*died"):
        batches = iter(loader)
        next(batches)
-
-
def _multi_instances_parallel_dataloader_worker():
    """Interleave two parallel DataLoaders, once per ``divide`` setting.

    A "train" loader (batch size 4) is iterated; on every fifth batch
    (indices 0, 5, 10, ...) a full pass over a "validation" loader
    (batch size 10) is run while the train iteration is still in progress.
    """
    dataset = init_dataset()

    for divide in (True, False):
        train_loader = DataLoader(
            dataset,
            sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
            num_workers=2,
            divide=divide,
        )
        val_loader = DataLoader(
            dataset,
            sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
            num_workers=2,
            divide=divide,
        )

        for step, (images, labels) in enumerate(train_loader):
            assert images.shape == (4, 1, 32, 32)
            assert labels.shape == (4,)

            if step % 5 != 0:
                continue
            # Nested pass over the second loader while the first is active.
            for val_images, val_labels in val_loader:
                assert val_images.shape == (10, 1, 32, 32)
                assert val_labels.shape == (10,)
-
-
def test_dataloader_parallel_multi_instances():
    """Run the interleaved multi-loader scenario in the current process."""
    # set max shared memory to 100M
    shared_memory_cap = "100000000"
    os.environ["MGE_PLASMA_MEMORY"] = shared_memory_cap

    _multi_instances_parallel_dataloader_worker()
-
-
def test_dataloader_parallel_multi_instances_multiprocessing():
    """Run the interleaved multi-loader scenario in four concurrent processes.

    Each child process executes ``_multi_instances_parallel_dataloader_worker``.
    The test fails if any child exits with a non-zero exit code; without that
    check, assertion failures or crashes inside the children would go
    unnoticed because ``join()`` alone never propagates them.
    """
    # set max shared memory to 100M
    os.environ["MGE_PLASMA_MEMORY"] = "100000000"

    import multiprocessing as mp

    workers = [
        mp.Process(target=_multi_instances_parallel_dataloader_worker)
        for _ in range(4)
    ]
    for worker in workers:
        worker.start()

    # Join every child before asserting so no process is left un-reaped
    # when one of them has failed.
    for worker in workers:
        worker.join()

    exit_codes = [worker.exitcode for worker in workers]
    assert exit_codes == [0] * len(workers), (
        "child process exit codes: {}".format(exit_codes)
    )
|