
feat(mge/distributed): add preload host data with op

fix(mge/distributed): change api name with preload

fix(mge/distributed): fix recursive model in preload tensor

fix(mge/distributed): fix recursive when cache contain None

GitOrigin-RevId: 80e2a6dd70
release-1.6
Megvii Engine Team, 3 years ago
parent commit dff7719e87
4 changed files with 436 additions and 29 deletions

  1. imperative/python/megengine/data/dataloader.py (+111, -23)
  2. imperative/python/megengine/device.py (+13, -0)
  3. imperative/python/megengine/distributed/group.py (+4, -6)
  4. imperative/python/test/unit/data/test_pre_dataloader.py (+308, -0)

imperative/python/megengine/data/dataloader.py (+111, -23)

@@ -15,12 +15,15 @@ import queue
import random
import threading
import time
from typing import Callable
from typing import Callable, Union

import numpy as np

from ..device import _sh, get_default_device
from ..functional.tensor import copy
from ..logger import get_logger
from ..random.rng import _random_seed_generator
from ..tensor import Tensor
from .collator import Collator
from .dataset import Dataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
@@ -44,7 +47,7 @@ def raise_timeout_error():

class DataLoader:
r"""Provides a convenient way to iterate on a given dataset.
DataLoader combines a dataset with
:class:`~.Sampler`, :class:`~.Transform` and :class:`~.Collator`,
making it flexible to fetch minibatches from a dataset continually.
@@ -66,6 +69,8 @@ class DataLoader:
``True`` means one batch is divided into :attr:`num_workers` pieces, and
the workers will process these pieces in parallel. ``False`` means
different sub-processes will process different batches. Default: False
preload: whether to apply the preloading strategy of the dataloader, which
parallelizes the host-to-device copy with kernel execution to improve
loading speed. Default: False. When enabled, the output changes from
np.ndarray to Tensor. The supported data types for preload are int, float,
list[int, float] and tuple[int, float]; other types are not supported.
(A usage sketch follows this file's diff.)
"""
__initialized = False

@@ -79,6 +84,7 @@ class DataLoader:
timeout: int = 0,
timeout_event: Callable = raise_timeout_error,
divide: bool = False,
preload: bool = False,
):
if num_workers < 0:
raise ValueError("num_workers should not be negative")
@@ -96,6 +102,7 @@ class DataLoader:
self.timeout_event = timeout_event

self.divide = divide
self.preload = preload

if isinstance(dataset, StreamDataset):
self.sampler = sampler if sampler else StreamSampler(batch_size=1)
@@ -145,24 +152,74 @@ class DataLoader:
self.num_workers = 0
if isinstance(self.dataset, StreamDataset):
if not self.num_workers:
return _SerialStreamDataLoaderIter(self)
return _SerialStreamDataLoaderIter(self, self.preload)
else:
return _ParallelStreamDataLoaderIter(self)
return _ParallelStreamDataLoaderIter(self, self.preload)
else:
assert isinstance(
self.dataset, Dataset
), "Can not recognize this kind of dataset: %s" % type(self.dataset)
if not self.num_workers:
return _SerialMapDataLoaderIter(self)
return _SerialMapDataLoaderIter(self, self.preload)
else:
return _ParallelMapDataLoaderIter(self)
return _ParallelMapDataLoaderIter(self, self.preload)

def __len__(self):
return len(self.sampler)


class _BaseMapDataLoaderIter:
def __init__(self, loader):
class PreLoader:
def __init__(self, preload):
if preload:
self.default_device = get_default_device()
self.pre_load_device = self.default_device + ":" + str(_sh.get_next())
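# e.g. "gpu0" plus a fresh stream index gives "gpu0:1", "gpu0:2", ...:
# the same device on a dedicated side stream for preload copies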
self.pre_load_device_cache = None
self.preload = preload

"""
strategy one: load from numpy data, and generate dtype tensor
"""

def _load_tensor(self, batch, cached=True):
if isinstance(batch, np.ndarray):
device = self.pre_load_device if cached else self.default_device
return Tensor(batch, device=device)
elif isinstance(batch, collections.abc.Mapping):
return {k: self._load_tensor(v, cached) for k, v in batch.items()}
elif isinstance(batch, tuple) and hasattr(batch, "_fields"): # namedtuple
return type(batch)(*(self._load_tensor(value, cached) for value in batch))
elif isinstance(batch, collections.abc.Sequence):
return [self._load_tensor(value, cached) for value in batch]
else:
return batch

"""
strategy two: load from cache that is already tensor just do d2d copy
"""

def _load_cache(self, data):
if isinstance(data, Tensor):
if data.device == self.default_device:
return data
return copy(data, device=self.default_device)
elif isinstance(data, collections.abc.Mapping):
return {k: self._load_cache(v) for k, v in data.items()}
elif isinstance(data, tuple) and hasattr(data, "_fields"): # namedtuple
return type(data)(*(self._load_cache(value) for value in data))
elif isinstance(data, collections.abc.Sequence):
return [self._load_cache(value) for value in data]
else:
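# non-Tensor leaves (including None values in the cache) pass
# through unchanged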
return data

def _swap_out_cache(self):
out = self._load_cache(self.pre_load_device_cache)
self.pre_load_device_cache = None # clean cache
return out
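# PreLoader keeps a one-slot pipeline: while the model consumes the batch
# that _swap_out_cache() just copied onto the default device, the next
# batch is already staged on the side stream; each __next__ swaps the
# staged batch out and refills the slot.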


class _BaseMapDataLoaderIter(PreLoader):
def __init__(self, loader, preload):
super().__init__(preload)
self.dataset = loader.dataset
self.sampler = loader.sampler
self.seed = _random_seed_generator().__next__()
@@ -184,16 +241,35 @@ class _BaseMapDataLoaderIter:
return self

def __next__(self):
if self.preload:
cached = self.pre_load_device_cache
if cached is None: # first and last
if self.num_processed >= len(self): # last
raise StopIteration
elif self.num_processed == 0: # first
self._try_load_tensor(cached=False) # first do the h2d
out = self._swap_out_cache()
self._try_load_tensor()
return out
else:
if self.num_processed >= len(self):
raise StopIteration
minibatch = self._get_next_batch()
self.num_processed += 1
return minibatch

def _try_load_tensor(self, cached=True):
if self.num_processed >= len(self):
return
else:
self.num_processed += 1
batch = self._get_next_batch()
self.pre_load_device_cache = self._load_tensor(batch, cached)
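# The preload handshake: pre_load_device_cache is None only before the
# first batch and after the last one. __next__ therefore loads the first
# batch synchronously onto the default device (cached=False), then keeps
# one batch staged; once the sampler is exhausted, _try_load_tensor stops
# refilling, leaving the cache None so the next __next__ raises
# StopIteration.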


class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
def __init__(self, loader):
super(_SerialMapDataLoaderIter, self).__init__(loader)
def __init__(self, loader, preload):
super(_SerialMapDataLoaderIter, self).__init__(loader, preload)
self.indices_iter = iter(self.sampler)

def _get_next_batch(self):
@@ -206,8 +282,8 @@ class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
__initialized = False

def __init__(self, loader):
super(_ParallelMapDataLoaderIter, self).__init__(loader)
def __init__(self, loader, preload):
super(_ParallelMapDataLoaderIter, self).__init__(loader, preload)

self.task_queues = [
multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
@@ -358,8 +434,9 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
self._shutdown()


class _BaseStreamDataLoaderIter:
def __init__(self, loader):
class _BaseStreamDataLoaderIter(PreLoader):
def __init__(self, loader, preload):
super().__init__(preload)
self.dataset = loader.dataset
self.sampler = loader.sampler
self.transform = loader.transform
@@ -388,12 +465,23 @@ class _BaseStreamDataLoaderIter:
return self

def __next__(self):
return self._get_next_batch()
if self.preload:
if self.pre_load_device_cache is None:
self._try_load_tensor(cached=False) # load in current
out = self._swap_out_cache()
self._try_load_tensor() # load in cached
return out
else:
return self._get_next_batch()

def _try_load_tensor(self, cached=True):
batch = self._get_next_batch()
self.pre_load_device_cache = self._load_tensor(batch, cached)


class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
def __init__(self, loader):
super().__init__(loader)
def __init__(self, loader, preload):
super().__init__(loader, preload)
self.dataset_iter = iter(self.dataset)
self.idx = 0
self.unused = []
@@ -439,8 +527,8 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
__initialized = False

def __init__(self, loader):
super().__init__(loader)
def __init__(self, loader, preload):
super().__init__(loader, preload)

self.shutdown_flag = multiprocessing.Value("i", 0)



imperative/python/megengine/device.py (+13, -0)

@@ -29,6 +29,19 @@ __all__ = [
]


class _stream_helper:
def __init__(self):
self.stream = 1
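# start handing out indices from 1 so the default stream (0) is never used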

def get_next(self):
out = self.stream
self.stream = self.stream + 1
return out


_sh = _stream_helper()


def _valid_device(inp):
if isinstance(inp, str) and re.match("^([cxg]pu|rocm)(\d+|\d+:\d+|x)$", inp):
return True
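The helper's contract, by way of illustration (a fresh instance is used
here for clarity; the module-level _sh is the one shared by dataloader
preloading and collective groups):

from megengine.device import _stream_helper

sh = _stream_helper()      # fresh instance, for illustration only
assert sh.get_next() == 1
assert sh.get_next() == 2  # every call returns a previously unused id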


imperative/python/megengine/distributed/group.py (+4, -6)

@@ -12,7 +12,7 @@ from typing import List, Optional, Tuple

from mprop import mproperty

from ..device import set_default_device, what_is_xpu
from ..device import _sh, set_default_device, what_is_xpu
from ..random import seed
from .server import Client, Server

@@ -27,7 +27,6 @@ class StaticData:
proc_rank = None
device = None
backend = None
next_stream = None
device_type = None
machine_ranks = None

@@ -43,6 +42,8 @@ class Group:

Args:
proc_ranks: rank list of the group; the first rank is the root rank.


"""

def __init__(self, proc_ranks):
@@ -55,9 +56,7 @@ class Group:
def reset(self, proc_ranks):
self.check(proc_ranks)
self.proc_ranks = proc_ranks
self.stream = _sd.next_stream
_sd.next_stream += 1
self.is_single_machine_cache = None
self.stream = _sh.get_next()

def check(self, proc_ranks):
assert _sd is not None, "please call init_process_group first"
@@ -160,7 +159,6 @@ def init_process_group(
_sd.proc_rank = rank
_sd.device = device
_sd.backend = backend
_sd.next_stream = 1
_sd.device_type = device_type

WORLD.reset(list(range(world_size)))
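With the counter moved into megengine.device, collective groups and
preloading DataLoader iterators draw their stream ids from the same shared
counter instead of the per-process _sd.next_stream, so the two features
never hand out the same stream. A sketch of the allocation order, assuming
a fresh process in which init_process_group has just run:

from megengine.device import _sh

# WORLD.reset() already consumed an id during init; subsequent ids go to
# new Groups or to preloading DataLoader iterators, in call order
first = _sh.get_next()
second = _sh.get_next()
assert second == first + 1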


imperative/python/test/unit/data/test_pre_dataloader.py (+308, -0)

@@ -0,0 +1,308 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import gc
import os
import platform
import time

import numpy as np
import pytest

from megengine.data.collator import Collator
from megengine.data.dataloader import DataLoader
from megengine.data.dataset import ArrayDataset, StreamDataset
from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler
from megengine.data.transform import (
Compose,
Normalize,
PseudoTransform,
ToMode,
Transform,
)


def init_dataset():
sample_num = 100
rand_data = np.random.randint(0, 255, size=(sample_num, 1, 32, 32), dtype=np.uint8)
label = np.random.randint(0, 10, size=(sample_num,), dtype=int)
dataset = ArrayDataset(rand_data, label)
return dataset


def test_dataloader_init():
dataset = init_dataset()
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=2, divide=True)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=-1)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, timeout=-1)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=0, divide=True)

dataloader = DataLoader(dataset, preload=True)
assert isinstance(dataloader.sampler, SequentialSampler)
assert isinstance(dataloader.transform, PseudoTransform)
assert isinstance(dataloader.collator, Collator)

dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=6, drop_last=False),
preload=True,
)
assert len(dataloader) == 17
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=6, drop_last=True),
preload=True,
)
assert len(dataloader) == 16


class MyStream(StreamDataset):
def __init__(self, number, batch=False, error_format=False, block=False):
self.number = number
self.batch = batch
self.error_format = error_format
self.block = block

def __iter__(self):
for cnt in range(self.number):
if self.block:
for _ in range(10):
time.sleep(1)
if self.batch:
data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8")
yield (True, (data, [cnt, cnt - self.number]))
else:
data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
if self.error_format:
yield (data, cnt)
else:
yield (False, (data, cnt))


@pytest.mark.parametrize("batch", [True, False])
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(batch, num_workers):
dataset = MyStream(100, batch=batch)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(
dataset,
sampler,
Compose([Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]),
num_workers=num_workers,
preload=True,
)

check_set = set()

for step, data in enumerate(dataloader):
if step == 10:
break
assert data[0]._tuple_shape == (4, 3, 2, 2)
assert data[1]._tuple_shape == (4,)
for i in data[1]:
assert i not in check_set
check_set.add(i)


def test_stream_dataloader_error():
dataset = MyStream(100, error_format=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(dataset, sampler, preload=True)
with pytest.raises(AssertionError, match=r".*tuple.*"):
data_iter = iter(dataloader)
next(data_iter)


@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader_timeout(num_workers):
dataset = MyStream(100, False, block=True)
sampler = StreamSampler(batch_size=4)

dataloader = DataLoader(
dataset, sampler, num_workers=num_workers, timeout=2, preload=True
)
with pytest.raises(RuntimeError, match=r".*timeout.*"):
data_iter = iter(dataloader)
next(data_iter)


def test_dataloader_serial():
dataset = init_dataset()
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
preload=True,
)
for (data, label) in dataloader:
assert data._tuple_shape == (4, 1, 32, 32)
assert label._tuple_shape == (4,)


def test_dataloader_parallel():
# set max shared memory to 100M
os.environ["MGE_PLASMA_MEMORY"] = "100000000"

dataset = init_dataset()
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=False,
preload=True,
)
for (data, label) in dataloader:
assert data._tuple_shape == (4, 1, 32, 32)
assert label._tuple_shape == (4,)

dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=True,
preload=True,
)
for (data, label) in dataloader:
assert data._tuple_shape == (4, 1, 32, 32)
assert label._tuple_shape == (4,)


@pytest.mark.skipif(
platform.system() == "Windows",
reason="dataloader do not support parallel on windows",
)
def test_dataloader_parallel_timeout():
dataset = init_dataset()

class TimeoutTransform(Transform):
def __init__(self):
pass

def apply(self, input):
time.sleep(10)
return input

dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
transform=TimeoutTransform(),
num_workers=2,
timeout=2,
preload=True,
)
with pytest.raises(RuntimeError, match=r".*timeout.*"):
data_iter = iter(dataloader)
batch_data = next(data_iter)


@pytest.mark.skipif(
platform.system() == "Windows",
reason="dataloader do not support parallel on windows",
)
def test_dataloader_parallel_worker_exception():
print("in target")
dataset = init_dataset()

class FakeErrorTransform(Transform):
def __init__(self):
pass

def apply(self, input):
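# 'x' is deliberately undefined: the NameError raised inside the worker
# should surface as a RuntimeError ("worker ... died") in the parent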
y = x + 1
return input

dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
transform=FakeErrorTransform(),
num_workers=2,
preload=True,
)
with pytest.raises(RuntimeError, match=r"worker.*died"):
data_iter = iter(dataloader)
batch_data = next(data_iter)


def _multi_instances_parallel_dataloader_worker():
dataset = init_dataset()

for divide_flag in [True, False]:
train_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=divide_flag,
preload=True,
)
val_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
num_workers=2,
divide=divide_flag,
preload=True,
)
for idx, (data, label) in enumerate(train_dataloader):
assert data._tuple_shape == (4, 1, 32, 32)
assert label._tuple_shape == (4,)
if idx % 5 == 0:
for val_data, val_label in val_dataloader:
assert val_data._tuple_shape == (10, 1, 32, 32)
assert val_label._tuple_shape == (10,)


def test_dataloader_parallel_multi_instances():
# set max shared memory to 100M
os.environ["MGE_PLASMA_MEMORY"] = "100000000"

_multi_instances_parallel_dataloader_worker()


@pytest.mark.isolated_distributed
def test_dataloader_parallel_multi_instances_multiprocessing():
gc.collect()
# set max shared memory to 100M
os.environ["MGE_PLASMA_MEMORY"] = "100000000"

import multiprocessing as mp

# mp.set_start_method("spawn")
processes = []
for i in range(4):
p = mp.Process(target=_multi_instances_parallel_dataloader_worker)
p.start()
processes.append(p)

for p in processes:
p.join()
assert p.exitcode == 0


@pytest.mark.parametrize("num_workers", [0, 2])
def test_timeout_event(num_workers):
def cb():
return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,))))

dataset = MyStream(100, block=True)
sampler = StreamSampler(batch_size=4)

dataloader = DataLoader(
dataset,
sampler,
num_workers=num_workers,
timeout=2,
timeout_event=cb,
preload=True,
)
for _, data in enumerate(dataloader):
np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3)))
np.testing.assert_equal(data[1], np.ones(shape=(4,)))
break
