fix(mge/distributed): change API name for preload
fix(mge/distributed): fix recursive traversal of nested models in preload tensor loading
fix(mge/distributed): fix recursion when the cache contains None
GitOrigin-RevId: 80e2a6dd70
release-1.6
@@ -15,12 +15,15 @@ import queue
 import random
 import threading
 import time
-from typing import Callable
+from typing import Callable, Union

 import numpy as np

+from ..device import _sh, get_default_device
+from ..functional.tensor import copy
 from ..logger import get_logger
 from ..random.rng import _random_seed_generator
+from ..tensor import Tensor
 from .collator import Collator
 from .dataset import Dataset, StreamDataset
 from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
@@ -44,7 +47,7 @@ def raise_timeout_error():
 class DataLoader:
     r"""Provides a convenient way to iterate on a given dataset.

     DataLoader combines a dataset with
     :class:`~.Sampler`, :class:`~.Transform` and :class:`~.Collator`,
     make it flexible to get minibatch continually from a dataset.
@@ -66,6 +69,8 @@ class DataLoader:
             ``True`` means one batch is divided into :attr:`num_workers` pieces, and
             the workers will process these pieces parallelly. ``False`` means
             different sub-process will process different batch. Default: False
+        preload: whether to enable the preloading strategy of the dataloader, which overlaps the host-to-device copy with kernel execution to speed up loading. Default: False
+            When enabled, the output changes from np.ndarray to Tensor; the supported element types are int, float, list[int, float] and tuple[int, float]; other types are not supported.
     """

     __initialized = False
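Note: a minimal usage sketch of the new flag (the dataset and shapes below are illustrative, not part of this change). With preload=True, batches come back as MegEngine Tensors already placed on the device instead of np.ndarray:

import numpy as np
from megengine.data.dataloader import DataLoader
from megengine.data.dataset import ArrayDataset
from megengine.data.sampler import RandomSampler

dataset = ArrayDataset(
    np.random.randint(0, 255, size=(100, 1, 32, 32), dtype=np.uint8),
    np.random.randint(0, 10, size=(100,), dtype=int),
)
dataloader = DataLoader(
    dataset,
    sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
    preload=True,  # batches are loaded to the device on a side stream
)
for data, label in dataloader:
    break  # data and label are now Tensors on the device, not np.ndarray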
@@ -79,6 +84,7 @@ class DataLoader:
         timeout: int = 0,
         timeout_event: Callable = raise_timeout_error,
         divide: bool = False,
+        preload: bool = False,
     ):
         if num_workers < 0:
             raise ValueError("num_workers should not be negative")
@@ -96,6 +102,7 @@ class DataLoader:
         self.timeout_event = timeout_event
         self.divide = divide
+        self.preload = preload

         if isinstance(dataset, StreamDataset):
             self.sampler = sampler if sampler else StreamSampler(batch_size=1)
@@ -145,24 +152,74 @@ class DataLoader:
             self.num_workers = 0

         if isinstance(self.dataset, StreamDataset):
             if not self.num_workers:
-                return _SerialStreamDataLoaderIter(self)
+                return _SerialStreamDataLoaderIter(self, self.preload)
             else:
-                return _ParallelStreamDataLoaderIter(self)
+                return _ParallelStreamDataLoaderIter(self, self.preload)
         else:
             assert isinstance(
                 self.dataset, Dataset
             ), "Can not recognize this kind of dataset: %s" % type(self.dataset)
             if not self.num_workers:
-                return _SerialMapDataLoaderIter(self)
+                return _SerialMapDataLoaderIter(self, self.preload)
             else:
-                return _ParallelMapDataLoaderIter(self)
+                return _ParallelMapDataLoaderIter(self, self.preload)

     def __len__(self):
         return len(self.sampler)


-class _BaseMapDataLoaderIter:
-    def __init__(self, loader):
+class PreLoader:
+    def __init__(self, preload):
+        if preload:
+            self.default_device = get_default_device()
+            self.pre_load_device = self.default_device + ":" + str(_sh.get_next())
+            self.pre_load_device_cache = None
+        self.preload = preload
+
+    """
+    strategy one: load from numpy data and generate device tensors
+    """
+
+    def _load_tensor(self, batch, cached=True):
+        if isinstance(batch, np.ndarray):
+            device = self.pre_load_device if cached else self.default_device
+            return Tensor(batch, device=device)
+        elif isinstance(batch, collections.abc.Mapping):
+            return {k: self._load_tensor(v, cached) for k, v in batch.items()}
+        elif isinstance(batch, tuple) and hasattr(batch, "_fields"):  # namedtuple
+            return type(batch)(*(self._load_tensor(value, cached) for value in batch))
+        elif isinstance(batch, collections.abc.Sequence):
+            return [self._load_tensor(value, cached) for value in batch]
+        else:
+            return batch
+
+    """
+    strategy two: the cache already holds tensors, so only a device-to-device copy is needed
+    """
+
+    def _load_cache(self, data):
+        if isinstance(data, Tensor):
+            if data.device == self.default_device:
+                return data
+            return copy(data, device=self.default_device)
+        elif isinstance(data, collections.abc.Mapping):
+            return {k: self._load_cache(v) for k, v in data.items()}
+        elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
+            return type(data)(*(self._load_cache(value) for value in data))
+        elif isinstance(data, collections.abc.Sequence):
+            return [self._load_cache(value) for value in data]
+        else:
+            return data
+
+    def _swap_out_cache(self):
+        out = self._load_cache(self.pre_load_device_cache)
+        self.pre_load_device_cache = None  # clean the cache
+        return out
+
+
+class _BaseMapDataLoaderIter(PreLoader):
+    def __init__(self, loader, preload):
+        super().__init__(preload)
         self.dataset = loader.dataset
         self.sampler = loader.sampler
         self.seed = _random_seed_generator().__next__()
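Note: _load_tensor and _load_cache share one traversal pattern — leaves are converted (np.ndarray to Tensor, or cached Tensor via a device-to-device copy), Mapping/namedtuple/Sequence containers are rebuilt recursively, and anything else falls through unchanged. That final else branch is what makes _swap_out_cache safe when the cache contains None, per the commit message. A standalone sketch of the pattern (map_structure and visit_leaf are illustrative names, not part of this change):

import collections.abc

def map_structure(batch, visit_leaf):
    # rebuild containers recursively, hand everything else to visit_leaf
    if isinstance(batch, collections.abc.Mapping):
        return {k: map_structure(v, visit_leaf) for k, v in batch.items()}
    elif isinstance(batch, tuple) and hasattr(batch, "_fields"):  # namedtuple
        return type(batch)(*(map_structure(v, visit_leaf) for v in batch))
    elif isinstance(batch, collections.abc.Sequence) and not isinstance(batch, str):
        return [map_structure(v, visit_leaf) for v in batch]
    else:
        return visit_leaf(batch)  # leaves; None and other values pass through

# map_structure({"img": img, "meta": (3, None)}, convert) keeps the dict/tuple
# shape and hands only the leaves (including None) to convert.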
@@ -184,16 +241,35 @@ class _BaseMapDataLoaderIter:
         return self

     def __next__(self):
+        if self.preload:
+            cached = self.pre_load_device_cache
+            if cached is None:  # first and last batch
+                if self.num_processed >= len(self):  # last batch
+                    raise StopIteration
+                elif self.num_processed == 0:  # first batch
+                    self._try_load_tensor(cached=False)  # first do the h2d copy
+            out = self._swap_out_cache()
+            self._try_load_tensor()
+            return out
+        else:
+            if self.num_processed >= len(self):
+                raise StopIteration
+            minibatch = self._get_next_batch()
+            self.num_processed += 1
+            return minibatch
+
+    def _try_load_tensor(self, cached=True):
         if self.num_processed >= len(self):
-            raise StopIteration
-        minibatch = self._get_next_batch()
-        self.num_processed += 1
-        return minibatch
+            return
+        else:
+            self.num_processed += 1
+            batch = self._get_next_batch()
+            self.pre_load_device_cache = self._load_tensor(batch, cached)
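Note: the preload path is a one-batch-deep pipeline: the first __next__ fills the cache synchronously on the default device (a plain h2d copy), and every later call swaps the finished batch out of the cache while the next batch is loaded onto the side-stream device. Reduced to a generic skeleton (illustrative only, not MegEngine API):

class _Prefetcher:
    def __init__(self, make_batch, length):
        self.make_batch = make_batch  # stands in for _get_next_batch + _load_tensor
        self.length = length
        self.num_processed = 0
        self.cache = None

    def __iter__(self):
        return self

    def __next__(self):
        if self.cache is None:
            if self.num_processed >= self.length:
                raise StopIteration  # pipeline drained
            self._fill()  # first call: load synchronously
        out, self.cache = self.cache, None
        self._fill()  # start loading the batch after this one
        return out

    def _fill(self):
        if self.num_processed < self.length:
            self.num_processed += 1
            self.cache = self.make_batch()

# e.g. list(_Prefetcher(lambda: "batch", 3)) yields exactly three batches.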
 class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
-    def __init__(self, loader):
-        super(_SerialMapDataLoaderIter, self).__init__(loader)
+    def __init__(self, loader, preload):
+        super(_SerialMapDataLoaderIter, self).__init__(loader, preload)
         self.indices_iter = iter(self.sampler)

     def _get_next_batch(self):
@@ -206,8 +282,8 @@ class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
 class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
     __initialized = False

-    def __init__(self, loader):
-        super(_ParallelMapDataLoaderIter, self).__init__(loader)
+    def __init__(self, loader, preload):
+        super(_ParallelMapDataLoaderIter, self).__init__(loader, preload)

         self.task_queues = [
             multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
@@ -358,8 +434,9 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
         self._shutdown()


-class _BaseStreamDataLoaderIter:
-    def __init__(self, loader):
+class _BaseStreamDataLoaderIter(PreLoader):
+    def __init__(self, loader, preload):
+        super().__init__(preload)
         self.dataset = loader.dataset
         self.sampler = loader.sampler
         self.transform = loader.transform
@@ -388,12 +465,23 @@ class _BaseStreamDataLoaderIter:
         return self

     def __next__(self):
-        return self._get_next_batch()
+        if self.preload:
+            if self.pre_load_device_cache is None:
+                self._try_load_tensor(cached=False)  # load on the current device
+            out = self._swap_out_cache()
+            self._try_load_tensor()  # load onto the cache device
+            return out
+        else:
+            return self._get_next_batch()
+
+    def _try_load_tensor(self, cached=True):
+        batch = self._get_next_batch()
+        self.pre_load_device_cache = self._load_tensor(batch, cached)
 class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
-    def __init__(self, loader):
-        super().__init__(loader)
+    def __init__(self, loader, preload):
+        super().__init__(loader, preload)
         self.dataset_iter = iter(self.dataset)
         self.idx = 0
         self.unused = []

@@ -439,8 +527,8 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
 class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
     __initialized = False

-    def __init__(self, loader):
-        super().__init__(loader)
+    def __init__(self, loader, preload):
+        super().__init__(loader, preload)

         self.shutdown_flag = multiprocessing.Value("i", 0)
@@ -29,6 +29,19 @@ __all__ = [
 ]


+class _stream_helper:
+    def __init__(self):
+        self.stream = 1
+
+    def get_next(self):
+        out = self.stream
+        self.stream = self.stream + 1
+        return out
+
+
+_sh = _stream_helper()
+
+
 def _valid_device(inp):
     if isinstance(inp, str) and re.match("^([cxg]pu|rocm)(\d+|\d+:\d+|x)$", inp):
         return True
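Note: _stream_helper is a process-global monotonic counter for stream ids; dataloader preloading and distributed groups now draw from the same counter (see the group.py change below), so the two users can no longer hand out the same stream twice. Given the _stream_helper class added above, the preload device string is just the default device plus a stream suffix, which still matches _valid_device's pattern ('gpu0' here stands in for whatever get_default_device() returns):

sh = _stream_helper()
default_device = "gpu0"
pre_load_device = default_device + ":" + str(sh.get_next())
assert pre_load_device == "gpu0:1"
assert sh.get_next() == 2  # the next consumer gets a distinct stream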
@@ -12,7 +12,7 @@ from typing import List, Optional, Tuple
 from mprop import mproperty

-from ..device import set_default_device, what_is_xpu
+from ..device import _sh, set_default_device, what_is_xpu
 from ..random import seed
 from .server import Client, Server
@@ -27,7 +27,6 @@ class StaticData:
     proc_rank = None
     device = None
     backend = None
-    next_stream = None
     device_type = None
     machine_ranks = None
@@ -43,6 +42,8 @@ class Group:
     Args:
         proc_ranks: rank list of the group, the first one is root rank.
     """

     def __init__(self, proc_ranks):
@@ -55,9 +56,7 @@ class Group:
     def reset(self, proc_ranks):
         self.check(proc_ranks)
         self.proc_ranks = proc_ranks
-        self.stream = _sd.next_stream
-        _sd.next_stream += 1
-        self.is_single_machine_cache = None
+        self.stream = _sh.get_next()

     def check(self, proc_ranks):
         assert _sd is not None, "please call init_process_group first"
@@ -160,7 +159,6 @@ def init_process_group(
     _sd.proc_rank = rank
     _sd.device = device
     _sd.backend = backend
-    _sd.next_stream = 1
     _sd.device_type = device_type

     WORLD.reset(list(range(world_size)))
@@ -0,0 +1,308 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import gc
+import os
+import platform
+import time
+
+import numpy as np
+import pytest
+
+from megengine.data.collator import Collator
+from megengine.data.dataloader import DataLoader
+from megengine.data.dataset import ArrayDataset, StreamDataset
+from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler
+from megengine.data.transform import (
+    Compose,
+    Normalize,
+    PseudoTransform,
+    ToMode,
+    Transform,
+)
+
+
+def init_dataset():
+    sample_num = 100
+    rand_data = np.random.randint(0, 255, size=(sample_num, 1, 32, 32), dtype=np.uint8)
+    label = np.random.randint(0, 10, size=(sample_num,), dtype=int)
+    dataset = ArrayDataset(rand_data, label)
+    return dataset
+
+
+def test_dataloader_init():
+    dataset = init_dataset()
+    with pytest.raises(ValueError):
+        dataloader = DataLoader(dataset, num_workers=2, divide=True)
+    with pytest.raises(ValueError):
+        dataloader = DataLoader(dataset, num_workers=-1)
+    with pytest.raises(ValueError):
+        dataloader = DataLoader(dataset, timeout=-1)
+    with pytest.raises(ValueError):
+        dataloader = DataLoader(dataset, num_workers=0, divide=True)
+
+    dataloader = DataLoader(dataset, preload=True)
+    assert isinstance(dataloader.sampler, SequentialSampler)
+    assert isinstance(dataloader.transform, PseudoTransform)
+    assert isinstance(dataloader.collator, Collator)
+
+    dataloader = DataLoader(
+        dataset,
+        sampler=RandomSampler(dataset, batch_size=6, drop_last=False),
+        preload=True,
+    )
+    assert len(dataloader) == 17
+    dataloader = DataLoader(
+        dataset,
+        sampler=RandomSampler(dataset, batch_size=6, drop_last=True),
+        preload=True,
+    )
+    assert len(dataloader) == 16
+class MyStream(StreamDataset):
+    def __init__(self, number, batch=False, error_format=False, block=False):
+        self.number = number
+        self.batch = batch
+        self.error_format = error_format
+        self.block = block
+
+    def __iter__(self):
+        for cnt in range(self.number):
+            if self.block:
+                for _ in range(10):
+                    time.sleep(1)
+            if self.batch:
+                data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8")
+                yield (True, (data, [cnt, cnt - self.number]))
+            else:
+                data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
+                if self.error_format:
+                    yield (data, cnt)
+                else:
+                    yield (False, (data, cnt))
+        return  # do not raise StopIteration inside a generator (PEP 479)
@pytest.mark.parametrize("batch", [True, False]) | |||||
@pytest.mark.parametrize("num_workers", [0, 2]) | |||||
def test_stream_dataloader(batch, num_workers): | |||||
dataset = MyStream(100, batch=batch) | |||||
sampler = StreamSampler(batch_size=4) | |||||
dataloader = DataLoader( | |||||
dataset, | |||||
sampler, | |||||
Compose([Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]), | |||||
num_workers=num_workers, | |||||
preload=True, | |||||
) | |||||
check_set = set() | |||||
for step, data in enumerate(dataloader): | |||||
if step == 10: | |||||
break | |||||
assert data[0]._tuple_shape == (4, 3, 2, 2) | |||||
assert data[1]._tuple_shape == (4,) | |||||
for i in data[1]: | |||||
assert i not in check_set | |||||
check_set.add(i) | |||||
+def test_stream_dataloader_error():
+    dataset = MyStream(100, error_format=True)
+    sampler = StreamSampler(batch_size=4)
+    dataloader = DataLoader(dataset, sampler, preload=True)
+    with pytest.raises(AssertionError, match=r".*tuple.*"):
+        data_iter = iter(dataloader)
+        next(data_iter)
@pytest.mark.parametrize("num_workers", [0, 2]) | |||||
def test_stream_dataloader_timeout(num_workers): | |||||
dataset = MyStream(100, False, block=True) | |||||
sampler = StreamSampler(batch_size=4) | |||||
dataloader = DataLoader( | |||||
dataset, sampler, num_workers=num_workers, timeout=2, preload=True | |||||
) | |||||
with pytest.raises(RuntimeError, match=r".*timeout.*"): | |||||
data_iter = iter(dataloader) | |||||
next(data_iter) | |||||
def test_dataloader_serial(): | |||||
dataset = init_dataset() | |||||
dataloader = DataLoader( | |||||
dataset, | |||||
sampler=RandomSampler(dataset, batch_size=4, drop_last=False), | |||||
preload=True, | |||||
) | |||||
for (data, label) in dataloader: | |||||
assert data._tuple_shape == (4, 1, 32, 32) | |||||
assert label._tuple_shape == (4,) | |||||
def test_dataloader_parallel(): | |||||
# set max shared memory to 100M | |||||
os.environ["MGE_PLASMA_MEMORY"] = "100000000" | |||||
dataset = init_dataset() | |||||
dataloader = DataLoader( | |||||
dataset, | |||||
sampler=RandomSampler(dataset, batch_size=4, drop_last=False), | |||||
num_workers=2, | |||||
divide=False, | |||||
preload=True, | |||||
) | |||||
for (data, label) in dataloader: | |||||
assert data._tuple_shape == (4, 1, 32, 32) | |||||
assert label._tuple_shape == (4,) | |||||
dataloader = DataLoader( | |||||
dataset, | |||||
sampler=RandomSampler(dataset, batch_size=4, drop_last=False), | |||||
num_workers=2, | |||||
divide=True, | |||||
preload=True, | |||||
) | |||||
for (data, label) in dataloader: | |||||
assert data._tuple_shape == (4, 1, 32, 32) | |||||
assert label._tuple_shape == (4,) | |||||
+@pytest.mark.skipif(
+    platform.system() == "Windows",
+    reason="dataloader does not support parallel on windows",
+)
+def test_dataloader_parallel_timeout():
+    dataset = init_dataset()
+
+    class TimeoutTransform(Transform):
+        def __init__(self):
+            pass
+
+        def apply(self, input):
+            time.sleep(10)
+            return input
+
+    dataloader = DataLoader(
+        dataset,
+        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
+        transform=TimeoutTransform(),
+        num_workers=2,
+        timeout=2,
+        preload=True,
+    )
+    with pytest.raises(RuntimeError, match=r".*timeout.*"):
+        data_iter = iter(dataloader)
+        batch_data = next(data_iter)
+@pytest.mark.skipif(
+    platform.system() == "Windows",
+    reason="dataloader does not support parallel on windows",
+)
+def test_dataloader_parallel_worker_exception():
+    dataset = init_dataset()
+
+    class FakeErrorTransform(Transform):
+        def __init__(self):
+            pass
+
+        def apply(self, input):
+            y = x + 1  # deliberate NameError: `x` is undefined, so the worker dies
+            return input
+
+    dataloader = DataLoader(
+        dataset,
+        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
+        transform=FakeErrorTransform(),
+        num_workers=2,
+        preload=True,
+    )
+    with pytest.raises(RuntimeError, match=r"worker.*died"):
+        data_iter = iter(dataloader)
+        batch_data = next(data_iter)
+def _multi_instances_parallel_dataloader_worker():
+    dataset = init_dataset()
+
+    for divide_flag in [True, False]:
+        train_dataloader = DataLoader(
+            dataset,
+            sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
+            num_workers=2,
+            divide=divide_flag,
+            preload=True,
+        )
+        val_dataloader = DataLoader(
+            dataset,
+            sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
+            num_workers=2,
+            divide=divide_flag,
+            preload=True,
+        )
+        for idx, (data, label) in enumerate(train_dataloader):
+            assert data._tuple_shape == (4, 1, 32, 32)
+            assert label._tuple_shape == (4,)
+            if idx % 5 == 0:
+                for val_data, val_label in val_dataloader:
+                    assert val_data._tuple_shape == (10, 1, 32, 32)
+                    assert val_label._tuple_shape == (10,)
+
+
+def test_dataloader_parallel_multi_instances():
+    # set max shared memory to 100M
+    os.environ["MGE_PLASMA_MEMORY"] = "100000000"
+
+    _multi_instances_parallel_dataloader_worker()
+
+
+@pytest.mark.isolated_distributed
+def test_dataloader_parallel_multi_instances_multiprocessing():
+    gc.collect()
+    # set max shared memory to 100M
+    os.environ["MGE_PLASMA_MEMORY"] = "100000000"
+
+    import multiprocessing as mp
+
+    # mp.set_start_method("spawn")
+    processes = []
+    for i in range(4):
+        p = mp.Process(target=_multi_instances_parallel_dataloader_worker)
+        p.start()
+        processes.append(p)
+
+    for p in processes:
+        p.join()
+        assert p.exitcode == 0
+
+
+@pytest.mark.parametrize("num_workers", [0, 2])
+def test_timeout_event(num_workers):
+    def cb():
+        return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,))))
+
+    dataset = MyStream(100, block=True)
+    sampler = StreamSampler(batch_size=4)
+
+    dataloader = DataLoader(
+        dataset,
+        sampler,
+        num_workers=num_workers,
+        timeout=2,
+        timeout_event=cb,
+        preload=True,
+    )
+    for _, data in enumerate(dataloader):
+        np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3)))
+        np.testing.assert_equal(data[1], np.ones(shape=(4,)))
+        break