Browse Source

fix(mge/data/dataloader): add refcount in _PlasmaStoreManager

GitOrigin-RevId: 9a95cb1a5d
release-0.5
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
c8a9094bb3
2 changed files with 55 additions and 1 deletion
  1. +6
    -1
      python_module/megengine/data/_queue.py
  2. +49
    -0
      python_module/test/unit/data/test_dataloader.py

+ 6
- 1
python_module/megengine/data/_queue.py View File

@@ -26,7 +26,7 @@ def _clear_plasma_store():
# `_PlasmaStoreManager.__del__` will not be called automatically in subprocess,
# so this function should be called explicitly
global MGE_PLASMA_STORE_MANAGER
if MGE_PLASMA_STORE_MANAGER is not None:
if MGE_PLASMA_STORE_MANAGER is not None and MGE_PLASMA_STORE_MANAGER.refcount == 0:
del MGE_PLASMA_STORE_MANAGER
MGE_PLASMA_STORE_MANAGER = None

@@ -50,6 +50,7 @@ class _PlasmaStoreManager:
stderr=None if debug_flag else subprocess.DEVNULL,
)
self.__initialized = True
self.refcount = 1

def __del__(self):
if self.__initialized and self.plasma_store.returncode is None:
@@ -83,6 +84,8 @@ class PlasmaShmQueue:
"Exception happened in starting plasma_store: {}\n"
"Tips: {}".format(str(e), err_info)
)
else:
MGE_PLASMA_STORE_MANAGER.refcount += 1

self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name

@@ -133,6 +136,8 @@ class PlasmaShmQueue:
def close(self):
    """Close this queue and release its hold on the shared plasma store.

    Closes the underlying queue, disconnects the plasma client, then drops
    this queue's reference on the global store manager. The store process
    itself is only torn down by `_clear_plasma_store` once the refcount
    reaches zero, i.e. when no other PlasmaShmQueue is still using it.
    """
    self.queue.close()
    self.disconnect_client()
    # Decrement before the teardown attempt so `_clear_plasma_store`
    # sees an up-to-date refcount.
    global MGE_PLASMA_STORE_MANAGER
    MGE_PLASMA_STORE_MANAGER.refcount -= 1
    _clear_plasma_store()

def cancel_join_thread(self):


+ 49
- 0
python_module/test/unit/data/test_dataloader.py View File

@@ -132,3 +132,52 @@ def test_dataloader_parallel_worker_exception():
with pytest.raises(RuntimeError, match=r"worker.*died"):
data_iter = iter(dataloader)
batch_data = next(data_iter)


def _multi_instances_parallel_dataloader_worker():
    """Drive a train and a val DataLoader concurrently, for both divide
    modes, asserting every batch comes out with the expected shapes.

    Having two loaders alive at once exercises the shared plasma store
    from multiple clients in the same process.
    """
    dataset = init_dataset()

    for divide in (True, False):
        train_loader = DataLoader(
            dataset,
            sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
            num_workers=2,
            divide=divide,
        )
        val_loader = DataLoader(
            dataset,
            sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
            num_workers=2,
            divide=divide,
        )
        for step, (inputs, labels) in enumerate(train_loader):
            assert inputs.shape == (4, 1, 32, 32)
            assert labels.shape == (4,)
            # Interleave a full validation sweep every 5 training steps so
            # both loaders are active simultaneously.
            if step % 5 == 0:
                for v_inputs, v_labels in val_loader:
                    assert v_inputs.shape == (10, 1, 32, 32)
                    assert v_labels.shape == (10,)


def test_dataloader_parallel_multi_instances():
    """Multiple DataLoader instances must coexist within one process."""
    # Cap the plasma shared-memory store at 100 MB for this test.
    os.environ["MGE_PLASMA_MEMORY"] = "100000000"

    _multi_instances_parallel_dataloader_worker()


def test_dataloader_parallel_multi_instances_multiprocessing():
    """Each child process must be able to run multiple DataLoader
    instances (and hence multiple plasma-store clients) concurrently.

    Fix over the original: `p.join()` alone ignores how the child exited,
    so an assertion failure or crash inside a worker process would let the
    test pass silently; we now assert a zero exit code for every worker.
    The commented-out `mp.set_start_method("spawn")` dead code is removed.
    """
    # Cap the plasma shared-memory store at 100 MB for this test.
    os.environ["MGE_PLASMA_MEMORY"] = "100000000"

    import multiprocessing as mp

    processes = [
        mp.Process(target=_multi_instances_parallel_dataloader_worker)
        for _ in range(4)
    ]
    # Start all workers first so they really run in parallel.
    for p in processes:
        p.start()

    for p in processes:
        p.join()
        # A non-zero exitcode means the worker raised or was killed;
        # surface that as a test failure instead of swallowing it.
        assert p.exitcode == 0, "worker exited with code {}".format(p.exitcode)

Loading…
Cancel
Save