You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

persistent_cache.py 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import argparse
  10. import getpass
  11. import json
  12. import os
  13. import shelve
  14. from .logconf import get_logger
  15. from .mgb import _PersistentCache
  16. from .version import __version__
  17. class _FakeRedisConn:
  18. def __init__(self):
  19. try:
  20. from ..hub.hub import _get_megengine_home
  21. cache_dir = os.path.expanduser(
  22. os.path.join(_get_megengine_home(), "persistent_cache")
  23. )
  24. os.makedirs(cache_dir, exist_ok=True)
  25. cache_file = os.path.join(cache_dir, "cache")
  26. self._dict = shelve.open(cache_file)
  27. self._is_shelve = True
  28. except:
  29. self._dict = {}
  30. self._is_shelve = False
  31. def get(self, key):
  32. if self._is_shelve and isinstance(key, bytes):
  33. key = key.decode("utf-8")
  34. return self._dict.get(key)
  35. def set(self, key, val):
  36. if self._is_shelve and isinstance(key, bytes):
  37. key = key.decode("utf-8")
  38. self._dict[key] = val
  39. def __del__(self):
  40. if self._is_shelve:
  41. self._dict.close()
  42. class PersistentCacheOnServer(_PersistentCache):
  43. _cached_conn = None
  44. _prefix = None
  45. _prev_get_refkeep = None
  46. @property
  47. def _conn(self):
  48. """get redis connection"""
  49. if self._cached_conn is None:
  50. try:
  51. self._cached_conn = self.make_redis_conn()
  52. except Exception as exc:
  53. get_logger().error(
  54. "failed to connect to cache server: {!r}; fallback to "
  55. "in-memory cache".format(exc)
  56. )
  57. self._cached_conn = _FakeRedisConn()
  58. self._prefix = self.make_user_prefix()
  59. return self._cached_conn
  60. @classmethod
  61. def make_user_prefix(cls):
  62. return "mgbcache:{}".format(getpass.getuser())
  63. @classmethod
  64. def make_redis_conn(cls):
  65. import redis
  66. conn = redis.StrictRedis(
  67. 'localhost', 6381,
  68. socket_connect_timeout=2, socket_timeout=1)
  69. return conn
  70. def _make_key(self, category, key):
  71. prefix_with_version = "{}:MGB{}".format(self._prefix, __version__)
  72. return b"@".join(
  73. (prefix_with_version.encode("ascii"), category.encode("ascii"), key)
  74. )
  75. def put(self, category, key, value):
  76. conn = self._conn
  77. key = self._make_key(category, key)
  78. conn.set(key, value)
  79. def get(self, category, key):
  80. conn = self._conn
  81. key = self._make_key(category, key)
  82. self._prev_get_refkeep = conn.get(key)
  83. return self._prev_get_refkeep
  84. def _clean():
  85. match = PersistentCacheOnServer.make_user_prefix() + "*"
  86. conn = PersistentCacheOnServer.make_redis_conn()
  87. cursor = 0
  88. nr_del = 0
  89. while True:
  90. cursor, values = conn.scan(cursor, match)
  91. if values:
  92. conn.delete(*values)
  93. nr_del += len(values)
  94. if not cursor:
  95. break
  96. print("{} cache entries deleted".format(nr_del))
  97. def main():
  98. parser = argparse.ArgumentParser(description="manage persistent cache")
  99. subp = parser.add_subparsers(description="action to be performed", dest="cmd")
  100. subp.required = True
  101. subp_clean = subp.add_parser("clean", help="clean all the cache of current user")
  102. subp_clean.set_defaults(action=_clean)
  103. args = parser.parse_args()
  104. args.action()
  105. if __name__ == "__main__":
  106. main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，无需区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发的感觉，欢迎访问 MegStudio 平台。

Contributors (1)