Browse Source

fix(mge/data/dataloader): fix typo, import and refine the logic of plasma_store

GitOrigin-RevId: 7d169a5294
tags/v0.3.2
Megvii Engine Team 5 years ago
parent
commit
afcda610f9
2 changed files with 28 additions and 7 deletions
  1. +25
    -5
      python_module/megengine/data/_queue.py
  2. +3
    -2
      python_module/megengine/data/dataloader.py

+ 25
- 5
python_module/megengine/data/_queue.py View File

@@ -16,6 +16,19 @@ import pyarrow.plasma as plasma


MGE_PLASMA_MEMORY = int(os.environ.get("MGE_PLASMA_MEMORY", 4000000000)) # 4GB MGE_PLASMA_MEMORY = int(os.environ.get("MGE_PLASMA_MEMORY", 4000000000)) # 4GB


# Each process only need to start one plasma store, so we set it as a global variable.
# TODO: how to share between different processes?
MGE_PLASMA_STORE_MANAGER = None


def _clear_plasma_store():
# `_PlasmaStoreManager.__del__` will not ne called automaticly in subprocess,
# so this function should be called explicitly
global MGE_PLASMA_STORE_MANAGER
if MGE_PLASMA_STORE_MANAGER is not None:
del MGE_PLASMA_STORE_MANAGER
MGE_PLASMA_STORE_MANAGER = None



class _PlasmaStoreManager: class _PlasmaStoreManager:
def __init__(self): def __init__(self):
@@ -34,11 +47,6 @@ class _PlasmaStoreManager:
self.plasma_store.kill() self.plasma_store.kill()




# Each process only need to start one plasma store, so we set it as a global variable.
# TODO: how to share between different processes?
MGE_PLASMA_STORE_MANAGER = _PlasmaStoreManager()


class PlasmaShmQueue: class PlasmaShmQueue:
def __init__(self, maxsize: int = 0): def __init__(self, maxsize: int = 0):
r"""Use pyarrow in-memory plasma store to implement shared memory queue. r"""Use pyarrow in-memory plasma store to implement shared memory queue.
@@ -51,6 +59,17 @@ class PlasmaShmQueue:
:param maxsize: maximum size of the queue, `None` means no limit. (default: ``None``) :param maxsize: maximum size of the queue, `None` means no limit. (default: ``None``)
""" """


# Lazy start the plasma store manager
global MGE_PLASMA_STORE_MANAGER
if MGE_PLASMA_STORE_MANAGER is None:
try:
MGE_PLASMA_STORE_MANAGER = _PlasmaStoreManager()
except FileNotFoundError as e:
raise FileNotFoundError(
"command 'plasma_store' not found in your $PATH!"
"Please make sure pyarrow installed and add into $PATH."
)

self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name


# TODO: how to catch the exception happened in `plasma.connect`? # TODO: how to catch the exception happened in `plasma.connect`?
@@ -100,6 +119,7 @@ class PlasmaShmQueue:
def close(self): def close(self):
self.queue.close() self.queue.close()
self.disconnect_client() self.disconnect_client()
_clear_plasma_store()


def cancel_join_thread(self): def cancel_join_thread(self):
self.queue.cancel_join_thread() self.queue.cancel_join_thread()

+ 3
- 2
python_module/megengine/data/dataloader.py View File

@@ -17,12 +17,13 @@ import numpy as np


import megengine as mge import megengine as mge


from ..logger import get_logger
from .collator import Collator from .collator import Collator
from .dataset import Dataset from .dataset import Dataset
from .sampler import Sampler, SequentialSampler from .sampler import Sampler, SequentialSampler
from .transform import PseudoTransform, Transform from .transform import PseudoTransform, Transform


logger = mge.get_logger(__name__)
logger = get_logger(__name__)




MP_QUEUE_GET_TIMEOUT = 5 MP_QUEUE_GET_TIMEOUT = 5
@@ -167,7 +168,7 @@ class _SerialDataLoaderIter(_BaseDataLoaderIter):




class _ParallelDataLoaderIter(_BaseDataLoaderIter): class _ParallelDataLoaderIter(_BaseDataLoaderIter):
__initialzed = False
__initialized = False


def __init__(self, loader): def __init__(self, loader):
super(_ParallelDataLoaderIter, self).__init__(loader) super(_ParallelDataLoaderIter, self).__init__(loader)


Loading…
Cancel
Save