|
- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import binascii
- import os
- import queue
- import subprocess
- from multiprocessing import Queue
-
- import pyarrow.plasma as plasma
-
# Capacity, in bytes, passed to ``plasma_store -m``. Override it with the
# MGE_PLASMA_MEMORY environment variable; default is 4 * 10**9 bytes (~4GB).
MGE_PLASMA_MEMORY = int(os.environ.get("MGE_PLASMA_MEMORY", 4 * 10 ** 9))

# A process only needs one plasma store, so the manager that owns it is kept
# in this module-level global and created lazily.
# TODO: how to share between different processes?
MGE_PLASMA_STORE_MANAGER = None
-
-
def _clear_plasma_store():
    """Shut down this process's plasma store (if any) and reset the global.

    ``_PlasmaStoreManager.__del__`` is not called automatically in a
    subprocess, so this function must be called explicitly. The store process
    is terminated here directly instead of relying on the manager being
    garbage collected: any other live reference to the manager would
    otherwise keep the store running.
    """
    global MGE_PLASMA_STORE_MANAGER
    # Reset the global first so a half-finished shutdown never leaves a stale
    # manager behind.
    manager, MGE_PLASMA_STORE_MANAGER = MGE_PLASMA_STORE_MANAGER, None
    if manager is None:
        return
    store = getattr(manager, "plasma_store", None)
    # poll() refreshes returncode; None means the store is still running.
    if store is not None and store.poll() is None:
        store.kill()
        # Reap the child so it does not linger as a zombie process.
        store.wait()
-
-
class _PlasmaStoreManager:
    """Own a ``plasma_store`` subprocess listening on a random socket path.

    The store is started on construction with ``MGE_PLASMA_MEMORY`` bytes of
    capacity and killed when the manager is garbage collected.
    """

    def __init__(self):
        # Random suffix so that different stores do not collide on one path.
        self.socket_name = "/tmp/mge_plasma_{}".format(
            binascii.hexlify(os.urandom(8)).decode()
        )
        debug_flag = self._debug_enabled()
        # Raises FileNotFoundError if `plasma_store` is not on $PATH.
        self.plasma_store = subprocess.Popen(
            ["plasma_store", "-s", self.socket_name, "-m", str(MGE_PLASMA_MEMORY)],
            stdout=None if debug_flag else subprocess.DEVNULL,
            stderr=None if debug_flag else subprocess.DEVNULL,
        )

    @staticmethod
    def _debug_enabled():
        """Return whether MGE_DATALOADER_PLASMA_DEBUG asks for store output.

        Unset, empty, "0", "false", "no" and "off" (case-insensitive) mean
        disabled; any other value enables forwarding of the store's
        stdout/stderr.
        """
        value = os.environ.get("MGE_DATALOADER_PLASMA_DEBUG", "")
        return value.strip().lower() not in ("", "0", "false", "no", "off")

    def __del__(self):
        # `plasma_store` is missing when Popen raised inside __init__; __del__
        # still runs on that half-constructed object.
        store = getattr(self, "plasma_store", None)
        # poll() refreshes returncode; a bare `returncode` check stays None for
        # a process that exited but was never reaped.
        if store is not None and store.poll() is None:
            store.kill()
            # Reap the child so it does not linger as a zombie process.
            store.wait()
-
-
class PlasmaShmQueue:
    def __init__(self, maxsize: int = 0):
        r"""Use pyarrow in-memory plasma store to implement shared memory queue.

        Compared to native `multiprocessing.Queue`, `PlasmaShmQueue` avoids
        pickle/unpickle and communication overhead, leading to better
        performance in multi-process applications.

        Payloads are stored in a plasma store shared by every queue of the
        current process (started lazily on first construction); only their
        ObjectIDs travel through an underlying :class:`multiprocessing.Queue`.

        :type maxsize: int
        :param maxsize: maximum number of items in the queue; ``0``, a negative
            value or ``None`` means no limit. (default: ``0``)
        :raises FileNotFoundError: if the ``plasma_store`` executable cannot be
            found in ``$PATH``.
        """

        # Lazy start the plasma store manager shared by this process.
        global MGE_PLASMA_STORE_MANAGER
        if MGE_PLASMA_STORE_MANAGER is None:
            try:
                MGE_PLASMA_STORE_MANAGER = _PlasmaStoreManager()
            except FileNotFoundError as e:
                raise FileNotFoundError(
                    "command 'plasma_store' not found in your $PATH! "
                    "Please make sure pyarrow installed and add into $PATH."
                ) from e

        self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name

        # The plasma client is connected lazily on the first put/get
        # (see `_get_client`).
        # TODO: how to catch the exception happened in `plasma.connect`?
        self.client = None

        # Used to store the header for the data (ObjectIDs).
        # multiprocessing.Queue rejects None, so map it to "unbounded" (0).
        if maxsize is None:
            maxsize = 0
        self.queue = Queue(maxsize)  # type: Queue

    def _get_client(self):
        """Return the plasma client, connecting to the store on first use."""
        if self.client is None:
            self.client = plasma.connect(self.socket_name)
        return self.client

    def put(self, data, block=True, timeout=None):
        """Store ``data`` in plasma and enqueue its ObjectID.

        :param block: forwarded to :meth:`multiprocessing.Queue.put`.
        :param timeout: forwarded to :meth:`multiprocessing.Queue.put`.
        :raises RuntimeError: if the plasma store is out of memory.
        :raises queue.Full: if the ObjectID could not be enqueued; the stored
            payload is deleted again before re-raising.
        """
        client = self._get_client()
        try:
            object_id = client.put(data)
        except plasma.PlasmaStoreFull as e:
            raise RuntimeError("plasma store out of memory!") from e
        try:
            self.queue.put(object_id, block, timeout)
        except BaseException:
            # Don't leak the payload when its header never reaches the queue
            # (e.g. queue.Full); the original exception is re-raised as-is.
            client.delete([object_id])
            raise

    def get(self, block=True, timeout=None):
        """Dequeue the next ObjectID, then return and delete its payload.

        :param block: forwarded to :meth:`multiprocessing.Queue.get`.
        :param timeout: forwarded to :meth:`multiprocessing.Queue.get`.
        :raises queue.Empty: propagated from the underlying queue.
        :raises RuntimeError: if the ObjectID is no longer in the plasma store.
        """
        client = self._get_client()
        object_id = self.queue.get(block, timeout)
        if not client.contains(object_id):
            raise RuntimeError(
                "ObjectID: {} not found in plasma store".format(object_id)
            )
        data = client.get(object_id)
        client.delete([object_id])
        return data

    def qsize(self):
        """Return the approximate number of queued items."""
        return self.queue.qsize()

    def empty(self):
        """Return whether the underlying ObjectID queue is empty."""
        return self.queue.empty()

    def join(self):
        """Wait for the underlying queue's background feeder thread.

        ``multiprocessing.Queue`` has no ``join()``; its counterpart of
        ``cancel_join_thread`` is ``join_thread``, which may only be called
        after :meth:`close`.
        """
        self.queue.join_thread()

    def disconnect_client(self):
        """Disconnect the plasma client if connected; the next put/get reconnects."""
        if self.client is not None:
            self.client.disconnect()
            # Forget the dead connection so `_get_client` reconnects lazily
            # and a second call does not touch it again.
            self.client = None

    def close(self):
        """Close the ObjectID queue, drop the client and clear the store.

        NOTE: `_clear_plasma_store` resets the process-wide store manager,
        which every `PlasmaShmQueue` in this process shares.
        """
        self.queue.close()
        self.disconnect_client()
        _clear_plasma_store()

    def cancel_join_thread(self):
        """Forward to :meth:`multiprocessing.Queue.cancel_join_thread`."""
        self.queue.cancel_join_thread()
|