You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

hub.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import functools
  10. import hashlib
  11. import os
  12. import sys
  13. import types
  14. from typing import Any, List
  15. from urllib.parse import urlparse
  16. from megengine.utils.http_download import download_from_url
  17. from ..distributed import is_distributed
  18. from ..logger import get_logger
  19. from ..serialization import load as _mge_load_serialized
  20. from .const import (
  21. DEFAULT_CACHE_DIR,
  22. DEFAULT_GIT_HOST,
  23. DEFAULT_PROTOCOL,
  24. ENV_MGE_HOME,
  25. ENV_XDG_CACHE_HOME,
  26. HTTP_READ_TIMEOUT,
  27. HUBCONF,
  28. HUBDEPENDENCY,
  29. )
  30. from .exceptions import InvalidProtocol
  31. from .fetcher import GitHTTPSFetcher, GitSSHFetcher
  32. from .tools import cd, check_module_exists, load_module
  33. logger = get_logger(__name__)
  34. PROTOCOLS = {
  35. "HTTPS": GitHTTPSFetcher,
  36. "SSH": GitSSHFetcher,
  37. }
  38. def _get_megengine_home() -> str:
  39. r"""MGE_HOME setting complies with the XDG Base Directory Specification"""
  40. megengine_home = os.path.expanduser(
  41. os.getenv(
  42. ENV_MGE_HOME,
  43. os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "megengine"),
  44. )
  45. )
  46. return megengine_home
  47. def _get_repo(
  48. git_host: str,
  49. repo_info: str,
  50. use_cache: bool = False,
  51. commit: str = None,
  52. protocol: str = DEFAULT_PROTOCOL,
  53. ) -> str:
  54. if protocol not in PROTOCOLS:
  55. raise InvalidProtocol(
  56. "Invalid protocol, the value should be one of {}.".format(
  57. ", ".join(PROTOCOLS.keys())
  58. )
  59. )
  60. cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
  61. with cd(cache_dir):
  62. fetcher = PROTOCOLS[protocol]
  63. repo_dir = fetcher.fetch(git_host, repo_info, use_cache, commit)
  64. return os.path.join(cache_dir, repo_dir)
  65. def _check_dependencies(module: types.ModuleType) -> None:
  66. if not hasattr(module, HUBDEPENDENCY):
  67. return
  68. dependencies = getattr(module, HUBDEPENDENCY)
  69. if not dependencies:
  70. return
  71. missing_deps = [m for m in dependencies if not check_module_exists(m)]
  72. if len(missing_deps):
  73. raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps)))
  74. def _init_hub(
  75. repo_info: str,
  76. git_host: str,
  77. use_cache: bool = True,
  78. commit: str = None,
  79. protocol: str = DEFAULT_PROTOCOL,
  80. ):
  81. r"""Imports hubmodule like python import.
  82. Args:
  83. repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
  84. tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
  85. git_host: host address of git repo. Eg: github.com
  86. use_cache: whether to use locally cached code or completely re-fetch.
  87. commit: commit id on github or gitlab.
  88. protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
  89. The value should be one of HTTPS, SSH.
  90. Returns:
  91. a python module.
  92. """
  93. cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
  94. os.makedirs(cache_dir, exist_ok=True)
  95. absolute_repo_dir = _get_repo(
  96. git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol
  97. )
  98. sys.path.insert(0, absolute_repo_dir)
  99. hubmodule = load_module(HUBCONF, os.path.join(absolute_repo_dir, HUBCONF))
  100. sys.path.remove(absolute_repo_dir)
  101. return hubmodule
  102. @functools.wraps(_init_hub)
  103. def import_module(*args, **kwargs):
  104. return _init_hub(*args, **kwargs)
  105. def list(
  106. repo_info: str,
  107. git_host: str = DEFAULT_GIT_HOST,
  108. use_cache: bool = True,
  109. commit: str = None,
  110. protocol: str = DEFAULT_PROTOCOL,
  111. ) -> List[str]:
  112. r"""Lists all entrypoints available in repo hubconf.
  113. Args:
  114. repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
  115. tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
  116. git_host: host address of git repo. Eg: github.com
  117. use_cache: whether to use locally cached code or completely re-fetch.
  118. commit: commit id on github or gitlab.
  119. protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
  120. The value should be one of HTTPS, SSH.
  121. Returns:
  122. all entrypoint names of the model.
  123. """
  124. hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
  125. return [
  126. _
  127. for _ in dir(hubmodule)
  128. if not _.startswith("__") and callable(getattr(hubmodule, _))
  129. ]
  130. def load(
  131. repo_info: str,
  132. entry: str,
  133. *args,
  134. git_host: str = DEFAULT_GIT_HOST,
  135. use_cache: bool = True,
  136. commit: str = None,
  137. protocol: str = DEFAULT_PROTOCOL,
  138. **kwargs
  139. ) -> Any:
  140. r"""Loads model from github or gitlab repo, with pretrained weights.
  141. Args:
  142. repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
  143. tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
  144. entry: an entrypoint defined in hubconf.
  145. git_host: host address of git repo. Eg: github.com
  146. use_cache: whether to use locally cached code or completely re-fetch.
  147. commit: commit id on github or gitlab.
  148. protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
  149. The value should be one of HTTPS, SSH.
  150. Returns:
  151. a single model with corresponding pretrained weights.
  152. """
  153. hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
  154. if not hasattr(hubmodule, entry) or not callable(getattr(hubmodule, entry)):
  155. raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))
  156. _check_dependencies(hubmodule)
  157. module = getattr(hubmodule, entry)(*args, **kwargs)
  158. return module
  159. def help(
  160. repo_info: str,
  161. entry: str,
  162. git_host: str = DEFAULT_GIT_HOST,
  163. use_cache: bool = True,
  164. commit: str = None,
  165. protocol: str = DEFAULT_PROTOCOL,
  166. ) -> str:
  167. r"""This function returns docstring of entrypoint ``entry`` by following steps:
  168. 1. Pull the repo code specified by git and repo_info.
  169. 2. Load the entry defined in repo's hubconf.py
  170. 3. Return docstring of function entry.
  171. Args:
  172. repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
  173. tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
  174. entry: an entrypoint defined in hubconf.py
  175. git_host: host address of git repo. Eg: github.com
  176. use_cache: whether to use locally cached code or completely re-fetch.
  177. commit: commit id on github or gitlab.
  178. protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
  179. The value should be one of HTTPS, SSH.
  180. Returns:
  181. docstring of entrypoint ``entry``.
  182. """
  183. hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
  184. if not hasattr(hubmodule, entry) or not callable(getattr(hubmodule, entry)):
  185. raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))
  186. doc = getattr(hubmodule, entry).__doc__
  187. return doc
  188. def load_serialized_obj_from_url(url: str, model_dir=None) -> Any:
  189. """Loads MegEngine serialized object from the given URL.
  190. If the object is already present in ``model_dir``, it's deserialized and
  191. returned. If no ``model_dir`` is specified, it will be ``MGE_HOME/serialized``.
  192. Args:
  193. url: url to serialized object.
  194. model_dir: dir to cache target serialized file.
  195. Returns:
  196. loaded object.
  197. """
  198. if model_dir is None:
  199. model_dir = os.path.join(_get_megengine_home(), "serialized")
  200. os.makedirs(model_dir, exist_ok=True)
  201. parts = urlparse(url)
  202. filename = os.path.basename(parts.path)
  203. # use hash as prefix to avoid filename conflict from different urls
  204. sha256 = hashlib.sha256()
  205. sha256.update(url.encode())
  206. digest = sha256.hexdigest()[:6]
  207. filename = digest + "_" + filename
  208. cached_file = os.path.join(model_dir, filename)
  209. logger.info(
  210. "load_serialized_obj_from_url: download to or using cached %s", cached_file
  211. )
  212. if not os.path.exists(cached_file):
  213. if is_distributed():
  214. logger.warning(
  215. "Downloading serialized object in DISTRIBUTED mode\n"
  216. " File may be downloaded multiple times. We recommend\n"
  217. " users to download in single process first."
  218. )
  219. download_from_url(url, cached_file, HTTP_READ_TIMEOUT)
  220. state_dict = _mge_load_serialized(cached_file)
  221. return state_dict
  222. class pretrained:
  223. r"""Decorator which helps to download pretrained weights from the given url.
  224. For example, we can decorate a resnet18 function as follows
  225. .. code-block::
  226. @hub.pretrained("https://url/to/pretrained_resnet18.pkl")
  227. def resnet18(**kwargs):
  228. Returns:
  229. When decorated function is called with ``pretrained=True``, MegEngine will automatically
  230. download and fill the returned model with pretrained weights.
  231. """
  232. def __init__(self, url):
  233. self.url = url
  234. def __call__(self, func):
  235. @functools.wraps(func)
  236. def pretrained_model_func(
  237. pretrained=False, **kwargs
  238. ): # pylint: disable=redefined-outer-name
  239. model = func(**kwargs)
  240. if pretrained:
  241. weights = load_serialized_obj_from_url(self.url)
  242. model.load_state_dict(weights)
  243. return model
  244. return pretrained_model_func
  245. __all__ = [
  246. "list",
  247. "load",
  248. "help",
  249. "load_serialized_obj_from_url",
  250. "pretrained",
  251. "import_module",
  252. ]

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台