You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

persistent_cache.py 3.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import argparse
  10. import getpass
  11. import os
  12. import sys
  13. import urllib.parse
  14. from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager
  15. from ..logger import get_logger
  16. from ..version import __version__, git_version
  17. class PersistentCacheManager(_PersistentCacheManager):
  18. def __init__(self):
  19. super().__init__()
  20. if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY":
  21. get_logger().info("fastrun use in-memory cache")
  22. self.open_memory()
  23. elif os.getenv("MGE_FASTRUN_CACHE_TYPE") == "FILE":
  24. self.open_file()
  25. else:
  26. self.open_redis()
  27. def open_memory(self):
  28. pass
  29. def open_file(self):
  30. cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR")
  31. try:
  32. if not cache_dir:
  33. from ..hub.hub import _get_megengine_home
  34. cache_dir = os.path.expanduser(
  35. os.path.join(_get_megengine_home(), "persistent_cache.bin")
  36. )
  37. os.makedirs(cache_dir, exist_ok=True)
  38. cache_file = os.path.join(cache_dir, "cache")
  39. with open(cache_file, "a"):
  40. pass
  41. assert self.try_open_file(cache_file), "cannot create file"
  42. get_logger().info("fastrun use in-file cache in {}".format(cache_dir))
  43. except Exception as exc:
  44. get_logger().error(
  45. "failed to create cache file in {} {!r}; fallback to "
  46. "in-memory cache".format(cache_dir, exc)
  47. )
  48. self.open_memory()
  49. def open_redis(self):
  50. prefix = "mgbcache:{}:MGB{}:GIT:{}".format(
  51. getpass.getuser(), __version__, git_version
  52. )
  53. url = os.getenv("MGE_FASTRUN_CACHE_URL")
  54. if url is None:
  55. self.open_file()
  56. try:
  57. assert sys.platform != "win32", "redis cache on windows not tested"
  58. parse_result = urllib.parse.urlparse(url, scheme="redis")
  59. assert parse_result.scheme == "redis", "unsupported scheme"
  60. assert not parse_result.username, "redis conn with username unsupported"
  61. assert self.try_open_redis(
  62. parse_result.hostname, parse_result.port, parse_result.password, prefix
  63. ), "connect failed"
  64. except Exception as exc:
  65. get_logger().error(
  66. "failed to connect to cache server {!r}; try fallback to "
  67. "in-file cache".format(exc)
  68. )
  69. self.open_file()
  70. _manager = None
  71. def get_manager():
  72. global _manager
  73. if _manager is None:
  74. _manager = PersistentCacheManager()
  75. return _manager
  76. def _clean():
  77. nr_del = get_manager().clean()
  78. if nr_del is not None:
  79. print("{} cache entries deleted".format(nr_del))
  80. def main():
  81. parser = argparse.ArgumentParser(description="manage persistent cache")
  82. subp = parser.add_subparsers(description="action to be performed", dest="cmd")
  83. subp.required = True
  84. subp_clean = subp.add_parser("clean", help="clean all the cache of current user")
  85. subp_clean.set_defaults(action=_clean)
  86. args = parser.parse_args()
  87. args.action()
  88. if __name__ == "__main__":
  89. main()