fix(mge/distributed): change api name with preload
fix(mge/distributed): fix recursive model in preload tensor
fix(mge/distributed): fix recursive when cache contain None
GitOrigin-RevId: 80e2a6dd70
release-1.6
@@ -15,12 +15,15 @@ import queue | |||
import random | |||
import threading | |||
import time | |||
from typing import Callable | |||
from typing import Callable, Union | |||
import numpy as np | |||
from ..device import _sh, get_default_device | |||
from ..functional.tensor import copy | |||
from ..logger import get_logger | |||
from ..random.rng import _random_seed_generator | |||
from ..tensor import Tensor | |||
from .collator import Collator | |||
from .dataset import Dataset, StreamDataset | |||
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler | |||
@@ -44,7 +47,7 @@ def raise_timeout_error(): | |||
class DataLoader: | |||
r"""Provides a convenient way to iterate on a given dataset. | |||
DataLoader combines a dataset with | |||
:class:`~.Sampler`, :class:`~.Transform` and :class:`~.Collator`, | |||
make it flexible to get minibatch continually from a dataset. | |||
@@ -66,6 +69,8 @@ class DataLoader: | |||
``True`` means one batch is divided into :attr:`num_workers` pieces, and | |||
the workers will process these pieces parallelly. ``False`` means | |||
different sub-process will process different batch. Default: False | |||
preload: Defines whether to apply the preloading strategy of the dataloader, which parallelizes the host-to-device copy with kernel execution to improve the loading speed. Default: False | |||
When enabled, the output changes from np.ndarray to Tensor. The types supported for preload are int, float, list[int, float] and tuple[int, float]; other types are not supported. | |||
""" | |||
__initialized = False | |||
@@ -79,6 +84,7 @@ class DataLoader: | |||
timeout: int = 0, | |||
timeout_event: Callable = raise_timeout_error, | |||
divide: bool = False, | |||
preload: bool = False, | |||
): | |||
if num_workers < 0: | |||
raise ValueError("num_workers should not be negative") | |||
@@ -96,6 +102,7 @@ class DataLoader: | |||
self.timeout_event = timeout_event | |||
self.divide = divide | |||
self.preload = preload | |||
if isinstance(dataset, StreamDataset): | |||
self.sampler = sampler if sampler else StreamSampler(batch_size=1) | |||
@@ -145,24 +152,74 @@ class DataLoader: | |||
self.num_workers = 0 | |||
if isinstance(self.dataset, StreamDataset): | |||
if not self.num_workers: | |||
return _SerialStreamDataLoaderIter(self) | |||
return _SerialStreamDataLoaderIter(self, self.preload) | |||
else: | |||
return _ParallelStreamDataLoaderIter(self) | |||
return _ParallelStreamDataLoaderIter(self, self.preload) | |||
else: | |||
assert isinstance( | |||
self.dataset, Dataset | |||
), "Can not recognize this kind of dataset: %s" % type(self.dataset) | |||
if not self.num_workers: | |||
return _SerialMapDataLoaderIter(self) | |||
return _SerialMapDataLoaderIter(self, self.preload) | |||
else: | |||
return _ParallelMapDataLoaderIter(self) | |||
return _ParallelMapDataLoaderIter(self, self.preload) | |||
def __len__(self): | |||
return len(self.sampler) | |||
class _BaseMapDataLoaderIter: | |||
def __init__(self, loader): | |||
class PreLoader:
    """Base class adding an optional "preload" cache to data loader iterators.

    When ``preload`` is truthy, one batch is kept in
    ``pre_load_device_cache`` as tensors created on ``pre_load_device`` and
    is moved to ``default_device`` when it is handed out.

    Args:
        preload: whether preloading is enabled. When falsy, only
            ``self.preload`` is set; the device attributes are not created.
    """

    def __init__(self, preload):
        if preload:
            self.default_device = get_default_device()
            # Default device name suffixed with ":<n>", where n is the next
            # index handed out by the shared ``_sh`` counter.
            self.pre_load_device = self.default_device + ":" + str(_sh.get_next())
            self.pre_load_device_cache = None
        self.preload = preload

    def _load_tensor(self, batch, cached=True):
        """Strategy one: load from numpy data and generate tensors.

        Walks ``batch`` recursively and replaces every ``np.ndarray`` with a
        ``Tensor``.

        Args:
            batch: an ndarray, or a mapping / namedtuple / sequence nesting
                ndarrays. Anything else is returned unchanged.
            cached: when True the tensors are created on
                ``pre_load_device``, otherwise on ``default_device``.

        Returns:
            The converted structure. Mappings become ``dict``, namedtuples
            keep their type, every other sequence becomes a ``list``.
        """
        if isinstance(batch, np.ndarray):
            device = self.pre_load_device if cached else self.default_device
            return Tensor(batch, device=device)
        elif isinstance(batch, (str, bytes)):
            # Text is a Sequence whose items are themselves text: recursing
            # into a str never terminates, so treat it as a leaf value.
            return batch
        elif isinstance(batch, collections.abc.Mapping):
            return {k: self._load_tensor(v, cached) for k, v in batch.items()}
        elif isinstance(batch, tuple) and hasattr(batch, "_fields"):  # namedtuple
            return type(batch)(*(self._load_tensor(value, cached) for value in batch))
        elif isinstance(batch, collections.abc.Sequence):
            return [self._load_tensor(value, cached) for value in batch]
        else:
            return batch

    def _load_cache(self, data):
        """Strategy two: load from the cache, which already holds tensors.

        Walks ``data`` recursively; every ``Tensor`` that is not already on
        ``default_device`` is copied there (a device-to-device copy).

        Args:
            data: a Tensor, or a mapping / namedtuple / sequence nesting
                tensors. Anything else (including ``None``) is returned
                unchanged.

        Returns:
            The structure with all tensors on ``default_device``.
        """
        if isinstance(data, Tensor):
            if data.device == self.default_device:
                return data
            return copy(data, device=self.default_device)
        elif isinstance(data, (str, bytes)):
            # Same leaf rule as in _load_tensor: never recurse into text.
            return data
        elif isinstance(data, collections.abc.Mapping):
            return {k: self._load_cache(v) for k, v in data.items()}
        elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
            return type(data)(*(self._load_cache(value) for value in data))
        elif isinstance(data, collections.abc.Sequence):
            return [self._load_cache(value) for value in data]
        else:
            return data

    def _swap_out_cache(self):
        """Return the cached batch moved to ``default_device`` and empty the cache."""
        out = self._load_cache(self.pre_load_device_cache)
        self.pre_load_device_cache = None  # clean cache
        return out
class _BaseMapDataLoaderIter(PreLoader): | |||
def __init__(self, loader, preload): | |||
super().__init__(preload) | |||
self.dataset = loader.dataset | |||
self.sampler = loader.sampler | |||
self.seed = _random_seed_generator().__next__() | |||
@@ -184,16 +241,35 @@ class _BaseMapDataLoaderIter: | |||
return self | |||
def __next__(self): | |||
if self.preload: | |||
cached = self.pre_load_device_cache | |||
if cached is None: # first and last | |||
if self.num_processed >= len(self): # last | |||
raise StopIteration | |||
elif self.num_processed == 0: # first | |||
self._try_load_tensor(cached=False) # first do the h2d | |||
out = self._swap_out_cache() | |||
self._try_load_tensor() | |||
return out | |||
else: | |||
if self.num_processed >= len(self): | |||
raise StopIteration | |||
minibatch = self._get_next_batch() | |||
self.num_processed += 1 | |||
return minibatch | |||
def _try_load_tensor(self, cached=True): | |||
if self.num_processed >= len(self): | |||
raise StopIteration | |||
minibatch = self._get_next_batch() | |||
self.num_processed += 1 | |||
return minibatch | |||
return | |||
else: | |||
self.num_processed += 1 | |||
batch = self._get_next_batch() | |||
self.pre_load_device_cache = self._load_tensor(batch, cached) | |||
class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter): | |||
def __init__(self, loader): | |||
super(_SerialMapDataLoaderIter, self).__init__(loader) | |||
def __init__(self, loader, preload): | |||
super(_SerialMapDataLoaderIter, self).__init__(loader, preload) | |||
self.indices_iter = iter(self.sampler) | |||
def _get_next_batch(self): | |||
@@ -206,8 +282,8 @@ class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter): | |||
class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter): | |||
__initialized = False | |||
def __init__(self, loader): | |||
super(_ParallelMapDataLoaderIter, self).__init__(loader) | |||
def __init__(self, loader, preload): | |||
super(_ParallelMapDataLoaderIter, self).__init__(loader, preload) | |||
self.task_queues = [ | |||
multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers) | |||
@@ -358,8 +434,9 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter): | |||
self._shutdown() | |||
class _BaseStreamDataLoaderIter: | |||
def __init__(self, loader): | |||
class _BaseStreamDataLoaderIter(PreLoader): | |||
def __init__(self, loader, preload): | |||
super().__init__(preload) | |||
self.dataset = loader.dataset | |||
self.sampler = loader.sampler | |||
self.transform = loader.transform | |||
@@ -388,12 +465,23 @@ class _BaseStreamDataLoaderIter: | |||
return self | |||
def __next__(self): | |||
return self._get_next_batch() | |||
if self.preload: | |||
if self.pre_load_device_cache is None: | |||
self._try_load_tensor(cached=False) # load in current | |||
out = self._swap_out_cache() | |||
self._try_load_tensor() # load in cached | |||
return out | |||
else: | |||
return self._get_next_batch() | |||
def _try_load_tensor(self, cached=True): | |||
batch = self._get_next_batch() | |||
self.pre_load_device_cache = self._load_tensor(batch, cached) | |||
class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter): | |||
def __init__(self, loader): | |||
super().__init__(loader) | |||
def __init__(self, loader, preload): | |||
super().__init__(loader, preload) | |||
self.dataset_iter = iter(self.dataset) | |||
self.idx = 0 | |||
self.unused = [] | |||
@@ -439,8 +527,8 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter): | |||
class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): | |||
__initialized = False | |||
def __init__(self, loader): | |||
super().__init__(loader) | |||
def __init__(self, loader, preload): | |||
super().__init__(loader, preload) | |||
self.shutdown_flag = multiprocessing.Value("i", 0) | |||
@@ -29,6 +29,19 @@ __all__ = [ | |||
] | |||
class _stream_helper: | |||
def __init__(self): | |||
self.stream = 1 | |||
def get_next(self): | |||
out = self.stream | |||
self.stream = self.stream + 1 | |||
return out | |||
_sh = _stream_helper() | |||
def _valid_device(inp): | |||
if isinstance(inp, str) and re.match("^([cxg]pu|rocm)(\d+|\d+:\d+|x)$", inp): | |||
return True | |||
@@ -12,7 +12,7 @@ from typing import List, Optional, Tuple | |||
from mprop import mproperty | |||
from ..device import set_default_device, what_is_xpu | |||
from ..device import _sh, set_default_device, what_is_xpu | |||
from ..random import seed | |||
from .server import Client, Server | |||
@@ -27,7 +27,6 @@ class StaticData: | |||
proc_rank = None | |||
device = None | |||
backend = None | |||
next_stream = None | |||
device_type = None | |||
machine_ranks = None | |||
@@ -43,6 +42,8 @@ class Group: | |||
Args: | |||
proc_ranks: rank list of the group, the first one is root rank. | |||
""" | |||
def __init__(self, proc_ranks): | |||
@@ -55,9 +56,7 @@ class Group: | |||
def reset(self, proc_ranks): | |||
self.check(proc_ranks) | |||
self.proc_ranks = proc_ranks | |||
self.stream = _sd.next_stream | |||
_sd.next_stream += 1 | |||
self.is_single_machine_cache = None | |||
self.stream = _sh.get_next() | |||
def check(self, proc_ranks): | |||
assert _sd is not None, "please call init_process_group first" | |||
@@ -160,7 +159,6 @@ def init_process_group( | |||
_sd.proc_rank = rank | |||
_sd.device = device | |||
_sd.backend = backend | |||
_sd.next_stream = 1 | |||
_sd.device_type = device_type | |||
WORLD.reset(list(range(world_size))) | |||
@@ -0,0 +1,308 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import gc | |||
import os | |||
import platform | |||
import time | |||
import numpy as np | |||
import pytest | |||
from megengine.data.collator import Collator | |||
from megengine.data.dataloader import DataLoader | |||
from megengine.data.dataset import ArrayDataset, StreamDataset | |||
from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler | |||
from megengine.data.transform import ( | |||
Compose, | |||
Normalize, | |||
PseudoTransform, | |||
ToMode, | |||
Transform, | |||
) | |||
def init_dataset():
    """Build an ArrayDataset of 100 random uint8 images with integer labels.

    Images have shape (1, 32, 32) with values drawn from [0, 255); labels
    are drawn from [0, 10).
    """
    num_samples = 100
    images = np.random.randint(
        0, 255, size=(num_samples, 1, 32, 32), dtype=np.uint8
    )
    labels = np.random.randint(0, 10, size=(num_samples,), dtype=int)
    return ArrayDataset(images, labels)
def test_dataloader_init():
    """Check constructor validation, default components and reported length."""
    dataset = init_dataset()

    # Each of these argument combinations must be rejected with ValueError.
    rejected_kwargs = (
        {"num_workers": 2, "divide": True},
        {"num_workers": -1},
        {"timeout": -1},
        {"num_workers": 0, "divide": True},
    )
    for kwargs in rejected_kwargs:
        with pytest.raises(ValueError):
            DataLoader(dataset, **kwargs)

    # With only a dataset given, the loader falls back to default components.
    dataloader = DataLoader(dataset, preload=True)
    assert isinstance(dataloader.sampler, SequentialSampler)
    assert isinstance(dataloader.transform, PseudoTransform)
    assert isinstance(dataloader.collator, Collator)

    # 100 samples in batches of 6: 17 batches keeping the remainder, 16 without.
    for drop_last, expected_len in ((False, 17), (True, 16)):
        dataloader = DataLoader(
            dataset,
            sampler=RandomSampler(dataset, batch_size=6, drop_last=drop_last),
            preload=True,
        )
        assert len(dataloader) == expected_len
class MyStream(StreamDataset):
    """Synthetic stream dataset used by the stream dataloader tests.

    Args:
        number: how many items the stream yields before it ends.
        batch: when True each item is ``(True, (data, [cnt, cnt - number]))``
            with ``data`` of shape (2, 2, 2, 3); otherwise a single sample of
            shape (2, 2, 3) is produced.
        error_foramt: (spelling kept for caller compatibility) when True and
            ``batch`` is False, items are yielded as a bare ``(data, cnt)``
            pair instead of ``(False, (data, cnt))``.
        block: when True, sleep ten seconds before producing each item so
            that timeout handling can be exercised.
    """

    def __init__(self, number, batch=False, error_foramt=False, block=False):
        self.number = number
        self.batch = batch
        self.error_format = error_foramt
        self.block = block

    def __iter__(self):
        for cnt in range(self.number):
            if self.block:
                # Ten one-second naps, i.e. a ten second stall per item.
                for _ in range(10):
                    time.sleep(1)
            if self.batch:
                data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8")
                yield (True, (data, [cnt, cnt - self.number]))
            else:
                data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
                if self.error_format:
                    yield (data, cnt)
                else:
                    yield (False, (data, cnt))
        # The generator ends by falling off the end. Explicitly raising
        # StopIteration here would surface as RuntimeError (PEP 479).
@pytest.mark.parametrize("batch", [True, False])
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(batch, num_workers):
    """Iterate a preloading stream loader and check batch shapes and labels."""
    dataset = MyStream(100, batch=batch)
    sampler = StreamSampler(batch_size=4)
    transform = Compose(
        [Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]
    )
    dataloader = DataLoader(
        dataset, sampler, transform, num_workers=num_workers, preload=True,
    )

    seen_labels = set()
    for step, data in enumerate(dataloader):
        # Only the first ten batches are inspected.
        if step == 10:
            break
        images, labels = data[0], data[1]
        assert images._tuple_shape == (4, 3, 2, 2)
        assert labels._tuple_shape == (4,)
        for label in labels:
            assert label not in seen_labels
            seen_labels.add(label)
def test_stream_dataloader_error():
    """A stream yielding bare (data, cnt) pairs must trip a tuple assertion."""
    dataloader = DataLoader(
        MyStream(100, error_foramt=True),
        StreamSampler(batch_size=4),
        preload=True,
    )
    with pytest.raises(AssertionError, match=r".*tuple.*"):
        batches = iter(dataloader)
        next(batches)
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader_timeout(num_workers):
    """A blocking stream with a 2 second timeout must raise a timeout error."""
    blocking_stream = MyStream(100, False, block=True)
    dataloader = DataLoader(
        blocking_stream,
        StreamSampler(batch_size=4),
        num_workers=num_workers,
        timeout=2,
        preload=True,
    )
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        batches = iter(dataloader)
        next(batches)
def test_dataloader_serial():
    """Every batch of the single-process preloading loader has the expected shape."""
    dataset = init_dataset()
    batch_sampler = RandomSampler(dataset, batch_size=4, drop_last=False)
    dataloader = DataLoader(dataset, sampler=batch_sampler, preload=True)
    for images, labels in dataloader:
        assert images._tuple_shape == (4, 1, 32, 32)
        assert labels._tuple_shape == (4,)
def test_dataloader_parallel():
    """Check batch shapes with two workers, first with divide=False, then divide=True."""
    # set max shared memory to 100M
    os.environ["MGE_PLASMA_MEMORY"] = "100000000"

    dataset = init_dataset()
    for divide in (False, True):
        dataloader = DataLoader(
            dataset,
            sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
            num_workers=2,
            divide=divide,
            preload=True,
        )
        for images, labels in dataloader:
            assert images._tuple_shape == (4, 1, 32, 32)
            assert labels._tuple_shape == (4,)
@pytest.mark.skipif(
    platform.system() == "Windows",
    reason="dataloader do not support parallel on windows",
)
def test_dataloader_parallel_timeout():
    """A transform slower than the loader timeout must raise a timeout error."""

    class TimeoutTransform(Transform):
        """Transform that stalls for ten seconds before returning its input."""

        def __init__(self):
            pass

        def apply(self, input):
            time.sleep(10)
            return input

    dataset = init_dataset()
    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        transform=TimeoutTransform(),
        num_workers=2,
        timeout=2,
        preload=True,
    )
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        batches = iter(dataloader)
        next(batches)
@pytest.mark.skipif(
    platform.system() == "Windows",
    reason="dataloader do not support parallel on windows",
)
def test_dataloader_parallel_worker_exception():
    """An exception inside a worker's transform must surface as a 'worker died' error."""
    print("in target")

    class FakeErrorTransform(Transform):
        """Transform whose apply() always fails."""

        def __init__(self):
            pass

        def apply(self, input):
            # ``x`` is undefined, so this line raises NameError when it runs.
            y = x + 1
            return input

    dataset = init_dataset()
    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        transform=FakeErrorTransform(),
        num_workers=2,
        preload=True,
    )
    with pytest.raises(RuntimeError, match=r"worker.*died"):
        batches = iter(dataloader)
        next(batches)
def _multi_instances_parallel_dataloader_worker():
    """Run a train loader and a validation loader side by side.

    For divide=True and then divide=False, iterate the train loader
    (batches of 4) and, on every fifth step, fully iterate the validation
    loader (batches of 10), checking all batch shapes.
    """
    dataset = init_dataset()

    for divide_flag in [True, False]:
        train_dataloader = DataLoader(
            dataset,
            sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
            num_workers=2,
            divide=divide_flag,
            preload=True,
        )
        val_dataloader = DataLoader(
            dataset,
            sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
            num_workers=2,
            divide=divide_flag,
            preload=True,
        )

        for step, (images, labels) in enumerate(train_dataloader):
            assert images._tuple_shape == (4, 1, 32, 32)
            assert labels._tuple_shape == (4,)
            if step % 5 != 0:
                continue
            # Every fifth training step runs a full validation pass.
            for val_images, val_labels in val_dataloader:
                assert val_images._tuple_shape == (10, 1, 32, 32)
                assert val_labels._tuple_shape == (10,)
def test_dataloader_parallel_multi_instances():
    """Run the multi-loader worker scenario in the current process."""
    # set max shared memory to 100M
    os.environ.update({"MGE_PLASMA_MEMORY": "100000000"})
    _multi_instances_parallel_dataloader_worker()
@pytest.mark.isolated_distributed
def test_dataloader_parallel_multi_instances_multiprocessing():
    """Run the multi-loader worker scenario in four child processes at once."""
    gc.collect()
    # set max shared memory to 100M
    os.environ["MGE_PLASMA_MEMORY"] = "100000000"

    import multiprocessing as mp

    # NOTE: no start method is set here; the platform default is used.
    children = []
    for _ in range(4):
        child = mp.Process(target=_multi_instances_parallel_dataloader_worker)
        child.start()
        children.append(child)

    # Every child must finish with a zero exit code.
    for child in children:
        child.join()
        assert child.exitcode == 0
@pytest.mark.parametrize("num_workers", [0, 2])
def test_timeout_event(num_workers):
    """On timeout, the loader must serve data from the timeout_event callback."""

    def cb():
        # Replacement item: zeros of shape (2, 2, 2, 3) and ones of shape (2,).
        zeros = np.zeros(shape=(2, 2, 2, 3))
        ones = np.ones(shape=(2,))
        return (True, (zeros, ones))

    dataloader = DataLoader(
        MyStream(100, block=True),
        StreamSampler(batch_size=4),
        num_workers=num_workers,
        timeout=2,
        timeout_event=cb,
        preload=True,
    )
    # Only the first batch is checked.
    for data in dataloader:
        np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3)))
        np.testing.assert_equal(data[1], np.ones(shape=(4,)))
        break