You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

hub.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import functools
  10. import hashlib
  11. import os
  12. import sys
  13. import types
  14. from typing import Any, List
  15. from urllib.parse import urlparse
  16. from megengine.utils.http_download import download_from_url
  17. from ..distributed import is_distributed
  18. from ..logger import get_logger
  19. from ..serialization import load as _mge_load_serialized
  20. from .const import (
  21. DEFAULT_CACHE_DIR,
  22. DEFAULT_GIT_HOST,
  23. DEFAULT_PROTOCOL,
  24. ENV_MGE_HOME,
  25. ENV_XDG_CACHE_HOME,
  26. HUBCONF,
  27. HUBDEPENDENCY,
  28. )
  29. from .exceptions import InvalidProtocol
  30. from .fetcher import GitHTTPSFetcher, GitSSHFetcher
  31. from .tools import cd, check_module_exists, load_module
  32. logger = get_logger(__name__)
  33. PROTOCOLS = {
  34. "HTTPS": GitHTTPSFetcher,
  35. "SSH": GitSSHFetcher,
  36. }
  37. def _get_megengine_home() -> str:
  38. r"""MGE_HOME setting complies with the XDG Base Directory Specification"""
  39. megengine_home = os.path.expanduser(
  40. os.getenv(
  41. ENV_MGE_HOME,
  42. os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "megengine"),
  43. )
  44. )
  45. return megengine_home
  46. def _get_repo(
  47. git_host: str,
  48. repo_info: str,
  49. use_cache: bool = False,
  50. commit: str = None,
  51. protocol: str = DEFAULT_PROTOCOL,
  52. ) -> str:
  53. if protocol not in PROTOCOLS:
  54. raise InvalidProtocol(
  55. "Invalid protocol, the value should be one of {}.".format(
  56. ", ".join(PROTOCOLS.keys())
  57. )
  58. )
  59. cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
  60. with cd(cache_dir):
  61. fetcher = PROTOCOLS[protocol]
  62. repo_dir = fetcher.fetch(git_host, repo_info, use_cache, commit)
  63. return os.path.join(cache_dir, repo_dir)
  64. def _check_dependencies(module: types.ModuleType) -> None:
  65. if not hasattr(module, HUBDEPENDENCY):
  66. return
  67. dependencies = getattr(module, HUBDEPENDENCY)
  68. if not dependencies:
  69. return
  70. missing_deps = [m for m in dependencies if not check_module_exists(m)]
  71. if len(missing_deps):
  72. raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps)))
  73. def _init_hub(
  74. repo_info: str,
  75. git_host: str,
  76. use_cache: bool = True,
  77. commit: str = None,
  78. protocol: str = DEFAULT_PROTOCOL,
  79. ):
  80. r"""Imports hubmodule like python import.
  81. Args:
  82. repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
  83. tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
  84. git_host: host address of git repo. Eg: github.com
  85. use_cache: whether to use locally cached code or completely re-fetch.
  86. commit: commit id on github or gitlab.
  87. protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
  88. The value should be one of HTTPS, SSH.
  89. Returns:
  90. a python module.
  91. """
  92. cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
  93. os.makedirs(cache_dir, exist_ok=True)
  94. absolute_repo_dir = _get_repo(
  95. git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol
  96. )
  97. sys.path.insert(0, absolute_repo_dir)
  98. hubmodule = load_module(HUBCONF, os.path.join(absolute_repo_dir, HUBCONF))
  99. sys.path.remove(absolute_repo_dir)
  100. return hubmodule
  101. @functools.wraps(_init_hub)
  102. def import_module(*args, **kwargs):
  103. return _init_hub(*args, **kwargs)
  104. def list(
  105. repo_info: str,
  106. git_host: str = DEFAULT_GIT_HOST,
  107. use_cache: bool = True,
  108. commit: str = None,
  109. protocol: str = DEFAULT_PROTOCOL,
  110. ) -> List[str]:
  111. r"""Lists all entrypoints available in repo hubconf.
  112. Args:
  113. repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
  114. tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
  115. git_host: host address of git repo. Eg: github.com
  116. use_cache: whether to use locally cached code or completely re-fetch.
  117. commit: commit id on github or gitlab.
  118. protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
  119. The value should be one of HTTPS, SSH.
  120. Returns:
  121. all entrypoint names of the model.
  122. """
  123. hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
  124. return [
  125. _
  126. for _ in dir(hubmodule)
  127. if not _.startswith("__") and callable(getattr(hubmodule, _))
  128. ]
  129. def load(
  130. repo_info: str,
  131. entry: str,
  132. *args,
  133. git_host: str = DEFAULT_GIT_HOST,
  134. use_cache: bool = True,
  135. commit: str = None,
  136. protocol: str = DEFAULT_PROTOCOL,
  137. **kwargs
  138. ) -> Any:
  139. r"""Loads model from github or gitlab repo, with pretrained weights.
  140. Args:
  141. repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
  142. tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
  143. entry: an entrypoint defined in hubconf.
  144. git_host: host address of git repo. Eg: github.com
  145. use_cache: whether to use locally cached code or completely re-fetch.
  146. commit: commit id on github or gitlab.
  147. protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
  148. The value should be one of HTTPS, SSH.
  149. Returns:
  150. a single model with corresponding pretrained weights.
  151. """
  152. hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
  153. if not hasattr(hubmodule, entry) or not callable(getattr(hubmodule, entry)):
  154. raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))
  155. _check_dependencies(hubmodule)
  156. module = getattr(hubmodule, entry)(*args, **kwargs)
  157. return module
  158. def help(
  159. repo_info: str,
  160. entry: str,
  161. git_host: str = DEFAULT_GIT_HOST,
  162. use_cache: bool = True,
  163. commit: str = None,
  164. protocol: str = DEFAULT_PROTOCOL,
  165. ) -> str:
  166. r"""This function returns docstring of entrypoint ``entry`` by following steps:
  167. 1. Pull the repo code specified by git and repo_info.
  168. 2. Load the entry defined in repo's hubconf.py
  169. 3. Return docstring of function entry.
  170. Args:
  171. repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
  172. tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
  173. entry: an entrypoint defined in hubconf.py
  174. git_host: host address of git repo. Eg: github.com
  175. use_cache: whether to use locally cached code or completely re-fetch.
  176. commit: commit id on github or gitlab.
  177. protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
  178. The value should be one of HTTPS, SSH.
  179. Returns:
  180. docstring of entrypoint ``entry``.
  181. """
  182. hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
  183. if not hasattr(hubmodule, entry) or not callable(getattr(hubmodule, entry)):
  184. raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))
  185. doc = getattr(hubmodule, entry).__doc__
  186. return doc
  187. def load_serialized_obj_from_url(url: str, model_dir=None) -> Any:
  188. """Loads MegEngine serialized object from the given URL.
  189. If the object is already present in ``model_dir``, it's deserialized and
  190. returned. If no ``model_dir`` is specified, it will be ``MGE_HOME/serialized``.
  191. Args:
  192. url: url to serialized object.
  193. model_dir: dir to cache target serialized file.
  194. Returns:
  195. loaded object.
  196. """
  197. if model_dir is None:
  198. model_dir = os.path.join(_get_megengine_home(), "serialized")
  199. os.makedirs(model_dir, exist_ok=True)
  200. parts = urlparse(url)
  201. filename = os.path.basename(parts.path)
  202. # use hash as prefix to avoid filename conflict from different urls
  203. sha256 = hashlib.sha256()
  204. sha256.update(url.encode())
  205. digest = sha256.hexdigest()[:6]
  206. filename = digest + "_" + filename
  207. cached_file = os.path.join(model_dir, filename)
  208. logger.info(
  209. "load_serialized_obj_from_url: download to or using cached %s", cached_file
  210. )
  211. if not os.path.exists(cached_file):
  212. if is_distributed():
  213. logger.warning(
  214. "Downloading serialized object in DISTRIBUTED mode\n"
  215. " File may be downloaded multiple times. We recommend\n"
  216. " users to download in single process first."
  217. )
  218. download_from_url(url, cached_file)
  219. state_dict = _mge_load_serialized(cached_file)
  220. return state_dict
  221. class pretrained:
  222. r"""Decorator which helps to download pretrained weights from the given url. Including fs, s3, http(s).
  223. For example, we can decorate a resnet18 function as follows
  224. .. code-block::
  225. @hub.pretrained("https://url/to/pretrained_resnet18.pkl")
  226. def resnet18(**kwargs):
  227. Returns:
  228. When decorated function is called with ``pretrained=True``, MegEngine will automatically
  229. download and fill the returned model with pretrained weights.
  230. """
  231. def __init__(self, url):
  232. self.url = url
  233. def __call__(self, func):
  234. @functools.wraps(func)
  235. def pretrained_model_func(
  236. pretrained=False, **kwargs
  237. ): # pylint: disable=redefined-outer-name
  238. model = func(**kwargs)
  239. if pretrained:
  240. weights = load_serialized_obj_from_url(self.url)
  241. model.load_state_dict(weights)
  242. return model
  243. return pretrained_model_func
  244. __all__ = [
  245. "list",
  246. "load",
  247. "help",
  248. "load_serialized_obj_from_url",
  249. "pretrained",
  250. "import_module",
  251. ]

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。