|
- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import functools
- import hashlib
- import os
- import sys
- import types
- from typing import Any, List
- from urllib.parse import urlparse
-
- from megengine.utils.http_download import download_from_url
-
- from ..distributed import is_distributed
- from ..logger import get_logger
- from ..serialization import load as _mge_load_serialized
- from .const import (
- DEFAULT_CACHE_DIR,
- DEFAULT_GIT_HOST,
- DEFAULT_PROTOCOL,
- ENV_MGE_HOME,
- ENV_XDG_CACHE_HOME,
- HTTP_READ_TIMEOUT,
- HUBCONF,
- HUBDEPENDENCY,
- )
- from .exceptions import InvalidProtocol
- from .fetcher import GitHTTPSFetcher, GitSSHFetcher
- from .tools import cd, check_module_exists, load_module
-
logger = get_logger(__name__)

# Fetcher class for each accepted value of the ``protocol`` argument; the keys
# (in this order) are also what the "Invalid protocol" error message lists.
PROTOCOLS = dict(
    HTTPS=GitHTTPSFetcher,
    SSH=GitSSHFetcher,
)
-
-
def _get_megengine_home() -> str:
    """Return the MegEngine home directory, following the XDG Base Directory Specification.

    Resolution order:

    1. the environment variable named by ``ENV_MGE_HOME``, if set and non-empty;
    2. otherwise ``<cache>/megengine``, where ``<cache>`` is the environment
       variable named by ``ENV_XDG_CACHE_HOME`` if set and non-empty, else
       ``DEFAULT_CACHE_DIR``.

    A leading ``~`` in the result is expanded.

    :return: path of the MegEngine home directory (not created here).
    """
    # The XDG spec requires an *empty* XDG_CACHE_HOME to be treated as unset.
    # ``os.getenv(name, default)`` would return "" in that case and produce a
    # path relative to the current working directory, so fall back on falsy.
    cache_home = os.getenv(ENV_XDG_CACHE_HOME) or DEFAULT_CACHE_DIR
    # Same reasoning for an empty MGE_HOME: treat it as unset.
    megengine_home = os.getenv(ENV_MGE_HOME) or os.path.join(cache_home, "megengine")
    return os.path.expanduser(megengine_home)
-
-
def _get_repo(
    git_host: str,
    repo_info: str,
    use_cache: bool = False,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> str:
    """Fetch ``repo_info`` from ``git_host`` into the hub cache and return its path.

    :param git_host: host address of the git repo.
    :param repo_info: repo spec forwarded unchanged to the fetcher.
    :param use_cache: forwarded unchanged to the fetcher's ``fetch``.
    :param commit: forwarded unchanged to the fetcher's ``fetch``.
    :param protocol: key into ``PROTOCOLS`` selecting the fetcher class.
    :return: the hub cache directory joined with the directory name returned
        by the fetcher.
    :raises InvalidProtocol: if ``protocol`` is not a key of ``PROTOCOLS``.
    """
    if protocol not in PROTOCOLS:
        supported = ", ".join(PROTOCOLS.keys())
        raise InvalidProtocol(
            "Invalid protocol, the value should be one of {}.".format(supported)
        )

    hub_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
    fetcher_cls = PROTOCOLS[protocol]
    # The fetch is performed inside ``cd(hub_dir)``; the returned name is then
    # resolved against that same directory.
    with cd(hub_dir):
        fetched_name = fetcher_cls.fetch(git_host, repo_info, use_cache, commit)
    return os.path.join(hub_dir, fetched_name)
-
-
def _check_dependencies(module: types.ModuleType) -> None:
    """Verify that every module listed in the hubconf's dependency attribute is available.

    The attribute named by ``HUBDEPENDENCY`` is optional. When it is missing
    or empty, nothing is checked.

    :param module: the imported hubconf module.
    :raises RuntimeError: listing every entry for which ``check_module_exists``
        returns a falsy value.
    """
    required = getattr(module, HUBDEPENDENCY, None)
    if required:
        absent = [name for name in required if not check_module_exists(name)]
        if absent:
            raise RuntimeError("Missing dependencies: {}".format(", ".join(absent)))
-
-
def _init_hub(
    repo_info: str,
    git_host: str,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
):
    """Fetch a hub repo and import its hubconf.py as a Python module.

    :param repo_info:
        a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
        tag/branch. The default branch is ``master`` if not specified.
        Example: ``"brain_sdk/MegBrain[:hub]"``
    :param git_host:
        host address of git repo
        Example: github.com
    :param use_cache:
        whether to use locally cached code or completely re-fetch
    :param commit:
        commit id on github or gitlab
    :param protocol:
        which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
        The value should be one of HTTPS, SSH.
    :return:
        hubconf.py as a python module
    :raises InvalidProtocol:
        if ``protocol`` is not a supported protocol (raised by ``_get_repo``).
    """
    cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
    os.makedirs(cache_dir, exist_ok=True)
    absolute_repo_dir = _get_repo(
        git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol
    )
    # Temporarily put the repo first on sys.path while hubconf.py executes,
    # presumably so it can import modules that live inside the repo.
    sys.path.insert(0, absolute_repo_dir)
    try:
        hubmodule = load_module(HUBCONF, os.path.join(absolute_repo_dir, HUBCONF))
    finally:
        # Always undo the sys.path change, even when hubconf.py fails to import;
        # otherwise a broken repo would stay on the import path for the process.
        sys.path.remove(absolute_repo_dir)

    return hubmodule
-
-
@functools.wraps(_init_hub)
def import_module(*call_args, **call_kwargs):
    # Public alias of ``_init_hub``. ``functools.wraps`` copies its name and
    # docstring, so a docstring written here would be overwritten anyway.
    return _init_hub(*call_args, **call_kwargs)
-
-
def list(
    repo_info: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> List[str]:
    """Return the names of all entrypoints exposed by a repo's hubconf.py.

    An entrypoint is any callable attribute of the imported hubconf module
    whose name does not start with ``"__"``.

    :param repo_info:
        a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
        tag/branch. The default branch is ``master`` if not specified.
        Example: ``"brain_sdk/MegBrain[:hub]"``
    :param git_host:
        host address of git repo
        Example: github.com
    :param use_cache:
        whether to use locally cached code or completely re-fetch
    :param commit:
        commit id on github or gitlab
    :param protocol:
        which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
        The value should be one of HTTPS, SSH.
    :return:
        entrypoint names, in the order returned by ``dir()`` on the module
    """
    hubmodule = _init_hub(
        repo_info, git_host, use_cache=use_cache, commit=commit, protocol=protocol
    )

    entrypoints = []
    for attr_name in dir(hubmodule):
        if attr_name.startswith("__"):
            continue  # skip dunder attributes of the module
        if callable(getattr(hubmodule, attr_name)):
            entrypoints.append(attr_name)
    return entrypoints
-
-
def load(
    repo_info: str,
    entry: str,
    *args,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
    **kwargs
) -> Any:
    """Build a model by calling entrypoint ``entry`` from a repo's hubconf.py.

    Extra positional and keyword arguments are forwarded to the entrypoint;
    for example, ``pretrained=True`` for entrypoints decorated with
    :class:`pretrained`, which then loads the pretrained weights.

    :param repo_info:
        a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
        tag/branch. The default branch is ``master`` if not specified.
        Example: ``"brain_sdk/MegBrain[:hub]"``
    :param entry:
        name of a callable defined in hubconf.py
    :param git_host:
        host address of git repo
        Example: github.com
    :param use_cache:
        whether to use locally cached code or completely re-fetch
    :param commit:
        commit id on github or gitlab
    :param protocol:
        which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
        The value should be one of HTTPS, SSH.
    :return:
        whatever the entrypoint returns (typically a model)
    :raises RuntimeError:
        if ``entry`` is not a callable attribute of hubconf.py, or if a
        dependency declared by hubconf.py is missing.
    """
    hubmodule = _init_hub(
        repo_info, git_host, use_cache=use_cache, commit=commit, protocol=protocol
    )

    entrypoint = getattr(hubmodule, entry, None)
    if not callable(entrypoint):
        raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))

    # Check the modules declared by hubconf.py before running the entrypoint.
    _check_dependencies(hubmodule)

    return entrypoint(*args, **kwargs)
-
-
def help(
    repo_info: str,
    entry: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> str:
    """Return the docstring of entrypoint ``entry`` from a repo's hubconf.py.

    The repo is fetched (or reused from the cache) and its hubconf.py is
    imported. The entrypoint itself is only looked up, never called.

    :param repo_info:
        a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
        tag/branch. The default branch is ``master`` if not specified.
        Example: ``"brain_sdk/MegBrain[:hub]"``
    :param entry:
        name of a callable defined in hubconf.py
    :param git_host:
        host address of git repo
        Example: github.com
    :param use_cache:
        whether to use locally cached code or completely re-fetch
    :param commit:
        commit id on github or gitlab
    :param protocol:
        which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
        The value should be one of HTTPS, SSH.
    :return:
        the entrypoint's ``__doc__`` (``None`` if it has no docstring)
    :raises RuntimeError:
        if ``entry`` is not a callable attribute of hubconf.py
    """
    hubmodule = _init_hub(
        repo_info, git_host, use_cache=use_cache, commit=commit, protocol=protocol
    )

    entrypoint = getattr(hubmodule, entry, None)
    if callable(entrypoint):
        return entrypoint.__doc__
    raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))
-
-
def load_serialized_obj_from_url(url: str, model_dir=None) -> Any:
    """Loads MegEngine serialized object from the given URL.

    The file is cached as ``<model_dir>/<hash>_<basename>``, where ``<hash>`` is
    the first 6 hex digits of the SHA-256 of ``url`` and ``<basename>`` is the
    last component of the URL path. If that file already exists it is
    deserialized directly; otherwise it is downloaded first.

    :param url: url to serialized object
    :param model_dir: dir to cache target serialized file. Defaults to
        ``MGE_HOME/serialized``. A leading ``~`` is expanded, and the directory
        is created if it does not exist.
    :return: loaded object
    """
    if model_dir is None:
        model_dir = os.path.join(_get_megengine_home(), "serialized")
    # Expand "~" so a caller-supplied path such as "~/models" is not created
    # literally, and make sure the directory exists for caller-supplied paths
    # as well as for the default one.
    model_dir = os.path.expanduser(model_dir)
    os.makedirs(model_dir, exist_ok=True)

    # Prefix with a short hash of the full URL so identically named files
    # coming from different URLs do not collide in the cache directory.
    digest = hashlib.sha256(url.encode()).hexdigest()[:6]
    filename = digest + "_" + os.path.basename(urlparse(url).path)
    cached_file = os.path.join(model_dir, filename)

    logger.info(
        "load_serialized_obj_from_url: download to or using cached %s", cached_file
    )
    if not os.path.exists(cached_file):
        if is_distributed():
            logger.warning(
                "Downloading serialized object in DISTRIBUTED mode\n"
                " File may be downloaded multiple times. We recommend\n"
                " users to download in single process first."
            )
        download_from_url(url, cached_file, HTTP_READ_TIMEOUT)

    # NOTE(review): ``serialization.load`` is presumably pickle-based (the
    # pretrained example URLs end in ``.pkl``); deserializing content fetched
    # from a URL can execute arbitrary code, so only use trusted URLs.
    return _mge_load_serialized(cached_file)
-
-
class pretrained:
    r"""
    Decorator factory that adds optional pretrained-weight loading to a model builder.

    Example, decorating a resnet18 builder:

    .. code-block::

        @hub.pretrained("https://url/to/pretrained_resnet18.pkl")
        def resnet18(**kwargs):
            return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)

    The decorated function takes an extra ``pretrained`` argument (default
    ``False``, also accepted as the sole positional argument) and forwards
    every keyword argument to the builder. When ``pretrained`` is truthy, the
    object at ``url`` is fetched with ``load_serialized_obj_from_url`` and
    passed to the built model's ``load_state_dict``.
    """

    def __init__(self, url):
        # Location of the serialized weights; read each time weights are requested.
        self.url = url

    def __call__(self, func):
        @functools.wraps(func)
        def with_optional_weights(
            pretrained=False, **kwargs
        ):  # pylint: disable=redefined-outer-name
            built = func(**kwargs)
            if not pretrained:
                return built
            built.load_state_dict(load_serialized_obj_from_url(self.url))
            return built

        return with_optional_weights
-
-
# Names exported by ``from ... import *``. Note that ``list`` and ``help``
# share their names with Python builtins.
__all__ = [
    "list", "load", "help",
    "load_serialized_obj_from_url", "pretrained", "import_module",
]
|