You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

persistent_cache.py 2.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import argparse
  10. import getpass
  11. import os
  12. import sys
  13. from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager
  14. from ..logger import get_logger
  15. from ..version import __version__, git_version
  16. class PersistentCacheManager(_PersistentCacheManager):
  17. def __init__(self):
  18. super().__init__()
  19. if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY":
  20. get_logger().info("fastrun use in-memory cache")
  21. self.open_memory()
  22. else:
  23. self.open_file()
  24. def open_memory(self):
  25. pass
  26. def open_file(self):
  27. cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR")
  28. try:
  29. if not cache_dir:
  30. from ..hub.hub import _get_megengine_home
  31. cache_dir = os.path.expanduser(
  32. os.path.join(_get_megengine_home(), "persistent_cache.bin")
  33. )
  34. os.makedirs(cache_dir, exist_ok=True)
  35. cache_file = os.path.join(cache_dir, "cache")
  36. with open(cache_file, "a"):
  37. pass
  38. assert self.try_open_file(cache_file), "cannot create file"
  39. get_logger().info("fastrun use in-file cache in {}".format(cache_dir))
  40. except Exception as exc:
  41. get_logger().error(
  42. "failed to create cache file in {} {!r}; fallback to "
  43. "in-memory cache".format(cache_dir, exc)
  44. )
  45. self.open_memory()
  46. _manager = None
  47. def get_manager():
  48. global _manager
  49. if _manager is None:
  50. _manager = PersistentCacheManager()
  51. return _manager