# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import hashlib
import os
import sys
import types
from typing import Any, List
from urllib.parse import urlparse

from megengine.utils.http_download import download_from_url

from ..distributed import is_distributed
from ..logger import get_logger
from ..serialization import load as _mge_load_serialized
from .const import (
    DEFAULT_CACHE_DIR,
    DEFAULT_GIT_HOST,
    DEFAULT_PROTOCOL,
    ENV_MGE_HOME,
    ENV_XDG_CACHE_HOME,
    HTTP_READ_TIMEOUT,
    HUBCONF,
    HUBDEPENDENCY,
)
from .exceptions import InvalidProtocol
from .fetcher import GitHTTPSFetcher, GitSSHFetcher
from .tools import cd, check_module_exists, load_module

logger = get_logger(__name__)

# Maps the user-facing protocol name to the fetcher class that implements it.
PROTOCOLS = {
    "HTTPS": GitHTTPSFetcher,
    "SSH": GitSSHFetcher,
}


def _get_megengine_home() -> str:
    """Returns the MegEngine home directory.

    The value of the ``ENV_MGE_HOME`` environment variable is used when set;
    otherwise it falls back to ``<cache_home>/megengine``, where ``cache_home``
    is the ``ENV_XDG_CACHE_HOME`` environment variable or ``DEFAULT_CACHE_DIR``
    (following the XDG Base Directory Specification). A leading ``~`` is
    expanded.

    :return: path of the MegEngine home directory.
    """
    xdg_cache_home = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    default_home = os.path.join(xdg_cache_home, "megengine")
    return os.path.expanduser(os.getenv(ENV_MGE_HOME, default_home))


def _get_repo(
    git_host: str,
    repo_info: str,
    use_cache: bool = False,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> str:
    """Fetches a repo into the hub cache directory and returns its path.

    :param git_host: host address of git repo.
    :param repo_info: repo description string, passed through to the fetcher.
    :param use_cache: forwarded to the fetcher's ``fetch``.
    :param commit: forwarded to the fetcher's ``fetch``.
    :param protocol: key of ``PROTOCOLS`` selecting the fetcher to use.
    :raises InvalidProtocol: if ``protocol`` is not a key of ``PROTOCOLS``.
    :return: ``<MGE_HOME>/hub`` joined with the directory returned by the fetcher.
    """
    if protocol not in PROTOCOLS:
        raise InvalidProtocol(
            "Invalid protocol, the value should be one of {}.".format(
                ", ".join(PROTOCOLS)
            )
        )
    cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
    # The fetch runs with ``cache_dir`` as the working directory.
    with cd(cache_dir):
        fetcher = PROTOCOLS[protocol]
        repo_dir = fetcher.fetch(git_host, repo_info, use_cache, commit)
        return os.path.join(cache_dir, repo_dir)


def _check_dependencies(module: types.ModuleType) -> None:
    """Verifies that every dependency declared by a hub module is available.

    The dependency names are read from the module attribute named by
    ``HUBDEPENDENCY``; a missing or empty attribute means nothing to check.

    :param module: the loaded hubconf module.
    :raises RuntimeError: if ``check_module_exists`` is false for any
        declared dependency; the message lists all of them.
    """
    dependencies = getattr(module, HUBDEPENDENCY, None)
    if not dependencies:
        return
    missing_deps = [m for m in dependencies if not check_module_exists(m)]
    if missing_deps:
        raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps)))


def _init_hub(
    repo_info: str,
    git_host: str,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
):
    """Imports hubmodule like python import.

    :param repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
        tag/branch. The default branch is ``master`` if not specified.
        Example: ``"brain_sdk/MegBrain[:hub]"``
    :param git_host: host address of git repo.
        Example: github.com
    :param use_cache: whether to use locally cached code or completely re-fetch.
    :param commit: commit id on github or gitlab.
    :param protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
        The value should be one of HTTPS, SSH.
    :return: a python module.
    """
    cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
    os.makedirs(cache_dir, exist_ok=True)
    absolute_repo_dir = _get_repo(
        git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol
    )
    # Put the repo first on sys.path while hubconf is loaded, and always take
    # it off again -- even if loading raises -- so that a broken hubconf does
    # not leave the repo directory on the import path.
    sys.path.insert(0, absolute_repo_dir)
    try:
        hubmodule = load_module(HUBCONF, os.path.join(absolute_repo_dir, HUBCONF))
    finally:
        sys.path.remove(absolute_repo_dir)

    return hubmodule


def _get_entry(hubmodule: types.ModuleType, entry: str):
    """Returns the callable named ``entry`` from a loaded hubconf module.

    :param hubmodule: the loaded hubconf module.
    :param entry: attribute name of the entrypoint.
    :raises RuntimeError: if the attribute is missing or is not callable.
    :return: the entrypoint callable.
    """
    func = getattr(hubmodule, entry, None)
    if not callable(func):
        raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))
    return func


@functools.wraps(_init_hub)
def import_module(*args, **kwargs):
    return _init_hub(*args, **kwargs)


def list(
    repo_info: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> List[str]:
    """Lists all entrypoints available in repo hubconf.

    :param repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
        tag/branch. The default branch is ``master`` if not specified.
        Example: ``"brain_sdk/MegBrain[:hub]"``
    :param git_host: host address of git repo.
        Example: github.com
    :param use_cache: whether to use locally cached code or completely re-fetch.
    :param commit: commit id on github or gitlab.
    :param protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
        The value should be one of HTTPS, SSH.
    :return: all entrypoint names of the model.
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)

    # Every callable attribute whose name does not start with "__" counts as
    # an entrypoint.
    return [
        name
        for name in dir(hubmodule)
        if not name.startswith("__") and callable(getattr(hubmodule, name))
    ]


def load(
    repo_info: str,
    entry: str,
    *args,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
    **kwargs
) -> Any:
    """Loads model from github or gitlab repo, with pretrained weights.

    :param repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
        tag/branch. The default branch is ``master`` if not specified.
        Example: ``"brain_sdk/MegBrain[:hub]"``
    :param entry: an entrypoint defined in hubconf.
    :param git_host: host address of git repo.
        Example: github.com
    :param use_cache: whether to use locally cached code or completely re-fetch.
    :param commit: commit id on github or gitlab.
    :param protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
        The value should be one of HTTPS, SSH.
    :raises RuntimeError: if ``entry`` is not a callable in hubconf, or if a
        dependency declared by hubconf is missing.
    :return: a single model with corresponding pretrained weights.
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)

    # Validate the entrypoint before the dependencies, then call it with the
    # remaining positional and keyword arguments.
    entry_func = _get_entry(hubmodule, entry)
    _check_dependencies(hubmodule)

    module = entry_func(*args, **kwargs)
    return module


def help(
    repo_info: str,
    entry: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> str:
    """This function returns docstring of entrypoint ``entry`` by following steps:

    1. Pull the repo code specified by git and repo_info.
    2. Load the entry defined in repo's hubconf.py
    3. Return docstring of function entry.

    :param repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
        tag/branch. The default branch is ``master`` if not specified.
        Example: ``"brain_sdk/MegBrain[:hub]"``
    :param entry: an entrypoint defined in hubconf.py
    :param git_host: host address of git repo.
        Example: github.com
    :param use_cache: whether to use locally cached code or completely re-fetch.
    :param commit: commit id on github or gitlab.
    :param protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
        The value should be one of HTTPS, SSH.
    :raises RuntimeError: if ``entry`` is not a callable in hubconf.
    :return: docstring of entrypoint ``entry``.
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)

    doc = _get_entry(hubmodule, entry).__doc__
    return doc


def load_serialized_obj_from_url(url: str, model_dir=None) -> Any:
    """Loads MegEngine serialized object from the given URL.

    If the object is already present in ``model_dir``, it's deserialized and
    returned. If no ``model_dir`` is specified, it will be ``MGE_HOME/serialized``.

    :param url: url to serialized object.
    :param model_dir: dir to cache target serialized file.

    :return: loaded object.
    """
    if model_dir is None:
        model_dir = os.path.join(_get_megengine_home(), "serialized")
    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)

    # use hash as prefix to avoid filename conflict from different urls
    sha256 = hashlib.sha256()
    sha256.update(url.encode())
    digest = sha256.hexdigest()[:6]
    filename = digest + "_" + filename

    cached_file = os.path.join(model_dir, filename)
    logger.info(
        "load_serialized_obj_from_url: download to or using cached %s", cached_file
    )
    if not os.path.exists(cached_file):
        if is_distributed():
            logger.warning(
                "Downloading serialized object in DISTRIBUTED mode\n"
                " File may be downloaded multiple times. We recommend\n"
                " users to download in single process first."
            )
        download_from_url(url, cached_file, HTTP_READ_TIMEOUT)

    # NOTE(review): the downloaded file is deserialized as-is; presumably the
    # serialization format can execute code on load, so ``url`` should point
    # to a trusted source -- confirm against ``..serialization.load``.
    state_dict = _mge_load_serialized(cached_file)
    return state_dict


class pretrained:
    r"""
    Decorator which helps to download pretrained weights from the given url.

    For example, we can decorate a resnet18 function as follows

    .. code-block::

        @hub.pretrained("https://url/to/pretrained_resnet18.pkl")
        def resnet18(**kwargs):
            return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)

    When decorated function is called with ``pretrained=True``, MegEngine will automatically
    download and fill the returned model with pretrained weights.
    """

    def __init__(self, url):
        # URL the pretrained weights are downloaded from.
        self.url = url

    def __call__(self, func):
        @functools.wraps(func)
        def pretrained_model_func(
            pretrained=False, **kwargs
        ):  # pylint: disable=redefined-outer-name
            # ``pretrained`` is consumed here; only the remaining keyword
            # arguments reach the wrapped function.
            model = func(**kwargs)
            if pretrained:
                weights = load_serialized_obj_from_url(self.url)
                model.load_state_dict(weights)
            return model

        return pretrained_model_func


__all__ = [
    "list",
    "load",
    "help",
    "load_serialized_obj_from_url",
    "pretrained",
    "import_module",
]