|
|
@@ -132,3 +132,52 @@ def test_dataloader_parallel_worker_exception(): |
|
|
|
with pytest.raises(RuntimeError, match=r"worker.*died"): |
|
|
|
data_iter = iter(dataloader) |
|
|
|
batch_data = next(data_iter) |
|
|
|
|
|
|
|
|
|
|
|
def _multi_instances_parallel_dataloader_worker(): |
|
|
|
dataset = init_dataset() |
|
|
|
|
|
|
|
for divide_flag in [True, False]: |
|
|
|
train_dataloader = DataLoader( |
|
|
|
dataset, |
|
|
|
sampler=RandomSampler(dataset, batch_size=4, drop_last=False), |
|
|
|
num_workers=2, |
|
|
|
divide=divide_flag, |
|
|
|
) |
|
|
|
val_dataloader = DataLoader( |
|
|
|
dataset, |
|
|
|
sampler=RandomSampler(dataset, batch_size=10, drop_last=False), |
|
|
|
num_workers=2, |
|
|
|
divide=divide_flag, |
|
|
|
) |
|
|
|
for idx, (data, label) in enumerate(train_dataloader): |
|
|
|
assert data.shape == (4, 1, 32, 32) |
|
|
|
assert label.shape == (4,) |
|
|
|
if idx % 5 == 0: |
|
|
|
for val_data, val_label in val_dataloader: |
|
|
|
assert val_data.shape == (10, 1, 32, 32) |
|
|
|
assert val_label.shape == (10,) |
|
|
|
|
|
|
|
|
|
|
|
def test_dataloader_parallel_multi_instances(): |
|
|
|
# set max shared memory to 100M |
|
|
|
os.environ["MGE_PLASMA_MEMORY"] = "100000000" |
|
|
|
|
|
|
|
_multi_instances_parallel_dataloader_worker() |
|
|
|
|
|
|
|
|
|
|
|
def test_dataloader_parallel_multi_instances_multiprocessing(): |
|
|
|
# set max shared memory to 100M |
|
|
|
os.environ["MGE_PLASMA_MEMORY"] = "100000000" |
|
|
|
|
|
|
|
import multiprocessing as mp |
|
|
|
|
|
|
|
# mp.set_start_method("spawn") |
|
|
|
processes = [] |
|
|
|
for i in range(4): |
|
|
|
p = mp.Process(target=_multi_instances_parallel_dataloader_worker) |
|
|
|
p.start() |
|
|
|
processes.append(p) |
|
|
|
|
|
|
|
for p in processes: |
|
|
|
p.join() |