|
- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-
- import argparse
- import getpass
- import os
- import sys
- import urllib.parse
-
- from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager
- from ..logger import get_logger
- from ..version import __version__, git_version
-
-
class PersistentCacheManager(_PersistentCacheManager):
    """Persistent cache for fastrun results, backed by memory, a file or redis.

    The backend is selected by the ``MGE_FASTRUN_CACHE_TYPE`` environment
    variable:

    * ``"MEMORY"`` -- in-memory cache only;
    * ``"FILE"`` -- on-disk cache (see :meth:`open_file`);
    * anything else / unset -- redis (see :meth:`open_redis`), which falls
      back to the file cache, which in turn falls back to the in-memory cache.
    """

    def __init__(self):
        super().__init__()
        cache_type = os.getenv("MGE_FASTRUN_CACHE_TYPE")
        if cache_type == "MEMORY":
            get_logger().info("fastrun use in-memory cache")
            self.open_memory()
        elif cache_type == "FILE":
            self.open_file()
        else:
            self.open_redis()

    def open_memory(self):
        """Select the in-memory cache.

        This is a no-op here. NOTE(review): presumably the native base class
        keeps an in-memory cache unless a file or redis backend is opened --
        confirm.
        """

    def open_file(self):
        """Open the on-disk cache, falling back to the in-memory cache on error.

        The cache file is ``<dir>/cache``. ``<dir>`` is ``MGE_FASTRUN_CACHE_DIR``
        if set; otherwise it is ``<megengine home>/persistent_cache.bin``, which
        is a directory despite its name. Both the directory and the file are
        created if missing. Any failure is logged and the in-memory cache is
        used instead.
        """
        cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR")
        try:
            if not cache_dir:
                from ..hub.hub import _get_megengine_home

                cache_dir = os.path.expanduser(
                    os.path.join(_get_megengine_home(), "persistent_cache.bin")
                )
            os.makedirs(cache_dir, exist_ok=True)
            cache_file = os.path.join(cache_dir, "cache")
            # Touch the file so it exists before it is handed to try_open_file().
            with open(cache_file, "a"):
                pass
            # Explicit check rather than ``assert``: under ``python -O`` an
            # assert (and therefore the try_open_file() call) would be stripped.
            if not self.try_open_file(cache_file):
                raise RuntimeError("cannot create file")
            get_logger().info("fastrun use in-file cache in {}".format(cache_dir))
        except Exception as exc:
            get_logger().error(
                "failed to create cache file in {} {!r}; fallback to "
                "in-memory cache".format(cache_dir, exc)
            )
            self.open_memory()

    def open_redis(self):
        """Connect to the redis cache named by ``MGE_FASTRUN_CACHE_URL``.

        Cache keys are prefixed with the current user name, the MegEngine
        version and the git revision. If the URL is unset, the file cache is
        used directly. If the connection cannot be set up, the error is logged
        and the file cache is used instead.
        """
        url = os.getenv("MGE_FASTRUN_CACHE_URL")
        if url is None:
            # No redis configured: go straight to the file cache. Returning here
            # avoids parsing a None URL, logging a bogus connection error and
            # opening the file cache a second time.
            self.open_file()
            return
        # Explicit raises instead of ``assert`` so the checks (and the
        # try_open_redis() call) still run under ``python -O``.
        try:
            if sys.platform == "win32":
                raise RuntimeError("redis cache on windows not tested")
            # getuser() can fail (e.g. no user entry); treat that like any other
            # setup failure and fall back to the file cache.
            prefix = "mgbcache:{}:MGB{}:GIT:{}".format(
                getpass.getuser(), __version__, git_version
            )
            parse_result = urllib.parse.urlparse(url, scheme="redis")
            if parse_result.scheme != "redis":
                raise ValueError("unsupported scheme")
            if parse_result.username:
                raise ValueError("redis conn with username unsupported")
            if not self.try_open_redis(
                parse_result.hostname, parse_result.port, parse_result.password, prefix
            ):
                raise RuntimeError("connect failed")
        except Exception as exc:
            get_logger().error(
                "failed to connect to cache server {!r}; try fallback to "
                "in-file cache".format(exc)
            )
            self.open_file()
-
-
# Lazily created, process-wide cache manager; see get_manager().
_manager = None


def get_manager():
    """Return the shared PersistentCacheManager, constructing it on first call."""
    global _manager
    if _manager is not None:
        return _manager
    _manager = PersistentCacheManager()
    return _manager
-
-
def _clean():
    """Handle the ``clean`` CLI action: drop cache entries and report how many.

    Nothing is printed when the manager's ``clean()`` returns ``None``.
    """
    deleted = get_manager().clean()
    if deleted is None:
        return
    print("{} cache entries deleted".format(deleted))
-
-
def main():
    """Command-line entry point: parse the sub-command and run its action.

    Currently the only sub-command is ``clean``. A sub-command is mandatory;
    argparse exits with a usage error if none is given.
    """
    parser = argparse.ArgumentParser(description="manage persistent cache")
    actions = parser.add_subparsers(description="action to be performed", dest="cmd")
    actions.required = True

    clean_cmd = actions.add_parser("clean", help="clean all the cache of current user")
    clean_cmd.set_defaults(action=_clean)

    # Each sub-parser stores its handler under ``action``; dispatch to it.
    parser.parse_args().action()
-
-
# Run the command-line interface when this module is executed as a script.
if __name__ == "__main__":
    main()
|