You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

hub.py 9.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. # -*- coding: utf-8 -*-
  2. import functools
  3. import hashlib
  4. import os
  5. import sys
  6. import types
  7. from typing import Any, List
  8. from urllib.parse import urlparse
  9. from megengine.utils.http_download import download_from_url
  10. from ..distributed import is_distributed
  11. from ..logger import get_logger
  12. from ..serialization import load as _mge_load_serialized
  13. from .const import (
  14. DEFAULT_CACHE_DIR,
  15. DEFAULT_GIT_HOST,
  16. DEFAULT_PROTOCOL,
  17. ENV_MGE_HOME,
  18. ENV_XDG_CACHE_HOME,
  19. HUBCONF,
  20. HUBDEPENDENCY,
  21. )
  22. from .exceptions import InvalidProtocol
  23. from .fetcher import GitHTTPSFetcher, GitSSHFetcher
  24. from .tools import cd, check_module_exists, load_module
  25. logger = get_logger(__name__)
  26. PROTOCOLS = {
  27. "HTTPS": GitHTTPSFetcher,
  28. "SSH": GitSSHFetcher,
  29. }
  30. def _get_megengine_home() -> str:
  31. r"""MGE_HOME setting complies with the XDG Base Directory Specification"""
  32. megengine_home = os.path.expanduser(
  33. os.getenv(
  34. ENV_MGE_HOME,
  35. os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "megengine"),
  36. )
  37. )
  38. return megengine_home
  39. def _get_repo(
  40. git_host: str,
  41. repo_info: str,
  42. use_cache: bool = False,
  43. commit: str = None,
  44. protocol: str = DEFAULT_PROTOCOL,
  45. ) -> str:
  46. if protocol not in PROTOCOLS:
  47. raise InvalidProtocol(
  48. "Invalid protocol, the value should be one of {}.".format(
  49. ", ".join(PROTOCOLS.keys())
  50. )
  51. )
  52. cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
  53. with cd(cache_dir):
  54. fetcher = PROTOCOLS[protocol]
  55. repo_dir = fetcher.fetch(git_host, repo_info, use_cache, commit)
  56. return os.path.join(cache_dir, repo_dir)
  57. def _check_dependencies(module: types.ModuleType) -> None:
  58. if not hasattr(module, HUBDEPENDENCY):
  59. return
  60. dependencies = getattr(module, HUBDEPENDENCY)
  61. if not dependencies:
  62. return
  63. missing_deps = [m for m in dependencies if not check_module_exists(m)]
  64. if len(missing_deps):
  65. raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps)))
  66. def _init_hub(
  67. repo_info: str,
  68. git_host: str,
  69. use_cache: bool = True,
  70. commit: str = None,
  71. protocol: str = DEFAULT_PROTOCOL,
  72. ):
  73. r"""Imports hubmodule like python import.
  74. Args:
  75. repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
  76. tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
  77. git_host: host address of git repo. Eg: github.com
  78. use_cache: whether to use locally cached code or completely re-fetch.
  79. commit: commit id on github or gitlab.
  80. protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
  81. The value should be one of HTTPS, SSH.
  82. Returns:
  83. a python module.
  84. """
  85. cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
  86. os.makedirs(cache_dir, exist_ok=True)
  87. absolute_repo_dir = _get_repo(
  88. git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol
  89. )
  90. sys.path.insert(0, absolute_repo_dir)
  91. hubmodule = load_module(HUBCONF, os.path.join(absolute_repo_dir, HUBCONF))
  92. sys.path.remove(absolute_repo_dir)
  93. return hubmodule
  94. @functools.wraps(_init_hub)
  95. def import_module(*args, **kwargs):
  96. return _init_hub(*args, **kwargs)
  97. def list(
  98. repo_info: str,
  99. git_host: str = DEFAULT_GIT_HOST,
  100. use_cache: bool = True,
  101. commit: str = None,
  102. protocol: str = DEFAULT_PROTOCOL,
  103. ) -> List[str]:
  104. r"""Lists all entrypoints available in repo hubconf.
  105. Args:
  106. repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
  107. tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
  108. git_host: host address of git repo. Eg: github.com
  109. use_cache: whether to use locally cached code or completely re-fetch.
  110. commit: commit id on github or gitlab.
  111. protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
  112. The value should be one of HTTPS, SSH.
  113. Returns:
  114. all entrypoint names of the model.
  115. """
  116. hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
  117. return [
  118. _
  119. for _ in dir(hubmodule)
  120. if not _.startswith("__") and callable(getattr(hubmodule, _))
  121. ]
  122. def load(
  123. repo_info: str,
  124. entry: str,
  125. *args,
  126. git_host: str = DEFAULT_GIT_HOST,
  127. use_cache: bool = True,
  128. commit: str = None,
  129. protocol: str = DEFAULT_PROTOCOL,
  130. **kwargs
  131. ) -> Any:
  132. r"""Loads model from github or gitlab repo, with pretrained weights.
  133. Args:
  134. repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
  135. tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
  136. entry: an entrypoint defined in hubconf.
  137. git_host: host address of git repo. Eg: github.com
  138. use_cache: whether to use locally cached code or completely re-fetch.
  139. commit: commit id on github or gitlab.
  140. protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
  141. The value should be one of HTTPS, SSH.
  142. Returns:
  143. a single model with corresponding pretrained weights.
  144. """
  145. hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
  146. if not hasattr(hubmodule, entry) or not callable(getattr(hubmodule, entry)):
  147. raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))
  148. _check_dependencies(hubmodule)
  149. module = getattr(hubmodule, entry)(*args, **kwargs)
  150. return module
  151. def help(
  152. repo_info: str,
  153. entry: str,
  154. git_host: str = DEFAULT_GIT_HOST,
  155. use_cache: bool = True,
  156. commit: str = None,
  157. protocol: str = DEFAULT_PROTOCOL,
  158. ) -> str:
  159. r"""This function returns docstring of entrypoint ``entry`` by following steps:
  160. 1. Pull the repo code specified by git and repo_info.
  161. 2. Load the entry defined in repo's hubconf.py
  162. 3. Return docstring of function entry.
  163. Args:
  164. repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
  165. tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
  166. entry: an entrypoint defined in hubconf.py
  167. git_host: host address of git repo. Eg: github.com
  168. use_cache: whether to use locally cached code or completely re-fetch.
  169. commit: commit id on github or gitlab.
  170. protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
  171. The value should be one of HTTPS, SSH.
  172. Returns:
  173. docstring of entrypoint ``entry``.
  174. """
  175. hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
  176. if not hasattr(hubmodule, entry) or not callable(getattr(hubmodule, entry)):
  177. raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))
  178. doc = getattr(hubmodule, entry).__doc__
  179. return doc
  180. def load_serialized_obj_from_url(url: str, model_dir=None) -> Any:
  181. """Loads MegEngine serialized object from the given URL.
  182. If the object is already present in ``model_dir``, it's deserialized and
  183. returned. If no ``model_dir`` is specified, it will be ``MGE_HOME/serialized``.
  184. Args:
  185. url: url to serialized object.
  186. model_dir: dir to cache target serialized file.
  187. Returns:
  188. loaded object.
  189. """
  190. if model_dir is None:
  191. model_dir = os.path.join(_get_megengine_home(), "serialized")
  192. os.makedirs(model_dir, exist_ok=True)
  193. parts = urlparse(url)
  194. filename = os.path.basename(parts.path)
  195. # use hash as prefix to avoid filename conflict from different urls
  196. sha256 = hashlib.sha256()
  197. sha256.update(url.encode())
  198. digest = sha256.hexdigest()[:6]
  199. filename = digest + "_" + filename
  200. cached_file = os.path.join(model_dir, filename)
  201. logger.info(
  202. "load_serialized_obj_from_url: download to or using cached %s", cached_file
  203. )
  204. if not os.path.exists(cached_file):
  205. if is_distributed():
  206. logger.warning(
  207. "Downloading serialized object in DISTRIBUTED mode\n"
  208. " File may be downloaded multiple times. We recommend\n"
  209. " users to download in single process first."
  210. )
  211. download_from_url(url, cached_file)
  212. state_dict = _mge_load_serialized(cached_file)
  213. return state_dict
  214. class pretrained:
  215. r"""Decorator which helps to download pretrained weights from the given url. Including fs, s3, http(s).
  216. For example, we can decorate a resnet18 function as follows
  217. .. code-block::
  218. @hub.pretrained("https://url/to/pretrained_resnet18.pkl")
  219. def resnet18(**kwargs):
  220. Returns:
  221. When decorated function is called with ``pretrained=True``, MegEngine will automatically
  222. download and fill the returned model with pretrained weights.
  223. """
  224. def __init__(self, url):
  225. self.url = url
  226. def __call__(self, func):
  227. @functools.wraps(func)
  228. def pretrained_model_func(
  229. pretrained=False, **kwargs
  230. ): # pylint: disable=redefined-outer-name
  231. model = func(**kwargs)
  232. if pretrained:
  233. weights = load_serialized_obj_from_url(self.url)
  234. model.load_state_dict(weights)
  235. return model
  236. return pretrained_model_func
  237. __all__ = [
  238. "list",
  239. "load",
  240. "help",
  241. "load_serialized_obj_from_url",
  242. "pretrained",
  243. "import_module",
  244. ]