Browse Source

fix(mge/data/dataloader): fix typo, import and refine the logic of plasma_store

GitOrigin-RevId: 7d169a5294
tags/v0.3.2
Megvii Engine Team 5 years ago
parent
commit
afcda610f9
2 changed files with 28 additions and 7 deletions
  1. +25
    -5
      python_module/megengine/data/_queue.py
  2. +3
    -2
      python_module/megengine/data/dataloader.py

+ 25
- 5
python_module/megengine/data/_queue.py View File

@@ -16,6 +16,19 @@ import pyarrow.plasma as plasma

MGE_PLASMA_MEMORY = int(os.environ.get("MGE_PLASMA_MEMORY", 4000000000)) # 4GB

# Each process only need to start one plasma store, so we set it as a global variable.
# TODO: how to share between different processes?
MGE_PLASMA_STORE_MANAGER = None


def _clear_plasma_store():
# `_PlasmaStoreManager.__del__` will not ne called automaticly in subprocess,
# so this function should be called explicitly
global MGE_PLASMA_STORE_MANAGER
if MGE_PLASMA_STORE_MANAGER is not None:
del MGE_PLASMA_STORE_MANAGER
MGE_PLASMA_STORE_MANAGER = None


class _PlasmaStoreManager:
def __init__(self):
@@ -34,11 +47,6 @@ class _PlasmaStoreManager:
self.plasma_store.kill()


# Each process only need to start one plasma store, so we set it as a global variable.
# TODO: how to share between different processes?
MGE_PLASMA_STORE_MANAGER = _PlasmaStoreManager()


class PlasmaShmQueue:
def __init__(self, maxsize: int = 0):
r"""Use pyarrow in-memory plasma store to implement shared memory queue.
@@ -51,6 +59,17 @@ class PlasmaShmQueue:
:param maxsize: maximum size of the queue, `None` means no limit. (default: ``None``)
"""

# Lazy start the plasma store manager
global MGE_PLASMA_STORE_MANAGER
if MGE_PLASMA_STORE_MANAGER is None:
try:
MGE_PLASMA_STORE_MANAGER = _PlasmaStoreManager()
except FileNotFoundError as e:
raise FileNotFoundError(
"command 'plasma_store' not found in your $PATH!"
"Please make sure pyarrow installed and add into $PATH."
)

self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name

# TODO: how to catch the exception happened in `plasma.connect`?
@@ -100,6 +119,7 @@ class PlasmaShmQueue:
def close(self):
self.queue.close()
self.disconnect_client()
_clear_plasma_store()

def cancel_join_thread(self):
self.queue.cancel_join_thread()

+ 3
- 2
python_module/megengine/data/dataloader.py View File

@@ -17,12 +17,13 @@ import numpy as np

import megengine as mge

from ..logger import get_logger
from .collator import Collator
from .dataset import Dataset
from .sampler import Sampler, SequentialSampler
from .transform import PseudoTransform, Transform

logger = mge.get_logger(__name__)
logger = get_logger(__name__)


MP_QUEUE_GET_TIMEOUT = 5
@@ -167,7 +168,7 @@ class _SerialDataLoaderIter(_BaseDataLoaderIter):


class _ParallelDataLoaderIter(_BaseDataLoaderIter):
__initialzed = False
__initialized = False

def __init__(self, loader):
super(_ParallelDataLoaderIter, self).__init__(loader)


Loading…
Cancel
Save