You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

_queue.py 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import binascii
  10. import os
  11. import queue
  12. import subprocess
  13. from multiprocessing import Queue
  14. import pyarrow.plasma as plasma
  15. MGE_PLASMA_MEMORY = int(os.environ.get("MGE_PLASMA_MEMORY", 4000000000)) # 4GB
  16. # Each process only need to start one plasma store, so we set it as a global variable.
  17. # TODO: how to share between different processes?
  18. MGE_PLASMA_STORE_MANAGER = None
  19. def _clear_plasma_store():
  20. # `_PlasmaStoreManager.__del__` will not ne called automaticly in subprocess,
  21. # so this function should be called explicitly
  22. global MGE_PLASMA_STORE_MANAGER
  23. if MGE_PLASMA_STORE_MANAGER is not None:
  24. del MGE_PLASMA_STORE_MANAGER
  25. MGE_PLASMA_STORE_MANAGER = None
  26. class _PlasmaStoreManager:
  27. def __init__(self):
  28. self.socket_name = "/tmp/mge_plasma_{}".format(
  29. binascii.hexlify(os.urandom(8)).decode()
  30. )
  31. debug_flag = bool(os.environ.get("MGE_DATALOADER_PLASMA_DEBUG", 0))
  32. self.plasma_store = subprocess.Popen(
  33. ["plasma_store", "-s", self.socket_name, "-m", str(MGE_PLASMA_MEMORY),],
  34. stdout=None if debug_flag else subprocess.DEVNULL,
  35. stderr=None if debug_flag else subprocess.DEVNULL,
  36. )
  37. def __del__(self):
  38. if self.plasma_store and self.plasma_store.returncode is None:
  39. self.plasma_store.kill()
  40. class PlasmaShmQueue:
  41. def __init__(self, maxsize: int = 0):
  42. r"""Use pyarrow in-memory plasma store to implement shared memory queue.
  43. Compared to native `multiprocess.Queue`, `PlasmaShmQueue` avoid pickle/unpickle
  44. and communication overhead, leading to better performance in multi-process
  45. application.
  46. :type maxsize: int
  47. :param maxsize: maximum size of the queue, `None` means no limit. (default: ``None``)
  48. """
  49. # Lazy start the plasma store manager
  50. global MGE_PLASMA_STORE_MANAGER
  51. if MGE_PLASMA_STORE_MANAGER is None:
  52. try:
  53. MGE_PLASMA_STORE_MANAGER = _PlasmaStoreManager()
  54. except FileNotFoundError as e:
  55. raise FileNotFoundError(
  56. "command 'plasma_store' not found in your $PATH!"
  57. "Please make sure pyarrow installed and add into $PATH."
  58. )
  59. self.socket_name = MGE_PLASMA_STORE_MANAGER.socket_name
  60. # TODO: how to catch the exception happened in `plasma.connect`?
  61. self.client = None
  62. # Used to store the header for the data.(ObjectIDs)
  63. self.queue = Queue(maxsize) # type: Queue
  64. def put(self, data, block=True, timeout=None):
  65. if self.client is None:
  66. self.client = plasma.connect(self.socket_name)
  67. try:
  68. object_id = self.client.put(data)
  69. except plasma.PlasmaStoreFull:
  70. raise RuntimeError("plasma store out of memory!")
  71. try:
  72. self.queue.put(object_id, block, timeout)
  73. except queue.Full:
  74. self.client.delete([object_id])
  75. raise queue.Full
  76. def get(self, block=True, timeout=None):
  77. if self.client is None:
  78. self.client = plasma.connect(self.socket_name)
  79. object_id = self.queue.get(block, timeout)
  80. if not self.client.contains(object_id):
  81. raise RuntimeError(
  82. "ObjectID: {} not found in plasma store".format(object_id)
  83. )
  84. data = self.client.get(object_id)
  85. self.client.delete([object_id])
  86. return data
  87. def qsize(self):
  88. return self.queue.qsize()
  89. def empty(self):
  90. return self.queue.empty()
  91. def join(self):
  92. self.queue.join()
  93. def disconnect_client(self):
  94. if self.client is not None:
  95. self.client.disconnect()
  96. def close(self):
  97. self.queue.close()
  98. self.disconnect_client()
  99. _clear_plasma_store()
  100. def cancel_join_thread(self):
  101. self.queue.cancel_join_thread()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。

Contributors (1)