You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

persistent_cache.py 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import argparse
  10. import contextlib
  11. import getpass
  12. import os
  13. import sys
  14. import urllib.parse
  15. import filelock
  16. from ..core._imperative_rt import PersistentCache as _PersistentCache
  17. from ..logger import get_logger
  18. from ..version import __version__, git_version
  19. class PersistentCacheOnServer(_PersistentCache):
  20. def __init__(self):
  21. super().__init__()
  22. cache_type = os.getenv("MGE_FASTRUN_CACHE_TYPE")
  23. if cache_type not in ("FILE", "MEMORY"):
  24. try:
  25. redis_config = self.get_redis_config()
  26. except Exception as exc:
  27. get_logger().error(
  28. "failed to connect to cache server {!r}; try fallback to "
  29. "in-file cache".format(exc)
  30. )
  31. else:
  32. self.add_config(
  33. "redis",
  34. redis_config,
  35. "fastrun use redis cache",
  36. "failed to connect to cache server",
  37. )
  38. if cache_type != "MEMORY":
  39. path = self.get_cache_file(self.get_cache_dir())
  40. self.add_config(
  41. "in-file",
  42. {"path": path},
  43. "fastrun use in-file cache in {}".format(path),
  44. "failed to create cache file in {}".format(path),
  45. )
  46. self.add_config(
  47. "in-memory",
  48. {},
  49. "fastrun use in-memory cache",
  50. "failed to create in-memory cache",
  51. )
  52. def get_cache_dir(self):
  53. cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR")
  54. if not cache_dir:
  55. from ..hub.hub import _get_megengine_home
  56. cache_dir = os.path.expanduser(
  57. os.path.join(_get_megengine_home(), "persistent_cache")
  58. )
  59. os.makedirs(cache_dir, exist_ok=True)
  60. return cache_dir
  61. def get_cache_file(self, cache_dir):
  62. cache_file = os.path.join(cache_dir, "cache.bin")
  63. with open(cache_file, "a"):
  64. pass
  65. return cache_file
  66. @contextlib.contextmanager
  67. def lock_cache_file(self, cache_dir):
  68. lock_file = os.path.join(cache_dir, "cache.lock")
  69. with filelock.FileLock(lock_file):
  70. yield
  71. def get_redis_config(self):
  72. url = os.getenv("MGE_FASTRUN_CACHE_URL")
  73. if url is None:
  74. return None
  75. assert sys.platform != "win32", "redis cache on windows not tested"
  76. prefix = "mgbcache:{}:MGB{}:GIT:{}".format(
  77. getpass.getuser(), __version__, git_version
  78. )
  79. parse_result = urllib.parse.urlparse(url)
  80. assert not parse_result.username, "redis conn with username unsupported"
  81. if parse_result.scheme == "redis":
  82. assert parse_result.hostname and parse_result.port, "invalid url"
  83. assert not parse_result.path
  84. config = {
  85. "hostname": parse_result.hostname,
  86. "port": str(parse_result.port),
  87. }
  88. elif parse_result.scheme == "redis+socket":
  89. assert not (parse_result.hostname or parse_result.port)
  90. assert parse_result.path
  91. config = {
  92. "unixsocket": parse_result.path,
  93. }
  94. else:
  95. assert False, "unsupported scheme"
  96. if parse_result.password is not None:
  97. config["password"] = parse_result.password
  98. config["prefix"] = prefix
  99. return config
  100. def flush(self):
  101. if self.config is not None and self.config.type == "in-file":
  102. with self.lock_cache_file(self.get_cache_dir()):
  103. super().flush()
  104. def _clean():
  105. nr_del = PersistentCacheOnServer().clean()
  106. if nr_del is not None:
  107. print("{} cache entries deleted".format(nr_del))
  108. def main():
  109. parser = argparse.ArgumentParser(description="manage persistent cache")
  110. subp = parser.add_subparsers(description="action to be performed", dest="cmd")
  111. subp.required = True
  112. subp_clean = subp.add_parser("clean", help="clean all the cache of current user")
  113. subp_clean.set_defaults(action=_clean)
  114. args = parser.parse_args()
  115. args.action()
  116. if __name__ == "__main__":
  117. main()