You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

hub.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import functools
  10. import hashlib
  11. import os
  12. import sys
  13. import types
  14. from typing import Any, List
  15. from urllib.parse import urlparse
  16. from megengine.utils.http_download import download_from_url
  17. from ..core.serialization import load as _mge_load_serialized
  18. from ..distributed import is_distributed
  19. from ..logger import get_logger
  20. from .const import (
  21. DEFAULT_CACHE_DIR,
  22. DEFAULT_GIT_HOST,
  23. DEFAULT_PROTOCOL,
  24. ENV_MGE_HOME,
  25. ENV_XDG_CACHE_HOME,
  26. HTTP_READ_TIMEOUT,
  27. HUBCONF,
  28. HUBDEPENDENCY,
  29. )
  30. from .exceptions import InvalidProtocol
  31. from .fetcher import GitHTTPSFetcher, GitSSHFetcher
  32. from .tools import cd, check_module_exists, load_module
  33. logger = get_logger(__name__)
  34. PROTOCOLS = {
  35. "HTTPS": GitHTTPSFetcher,
  36. "SSH": GitSSHFetcher,
  37. }
  38. def _get_megengine_home() -> str:
  39. """MGE_HOME setting complies with the XDG Base Directory Specification
  40. """
  41. megengine_home = os.path.expanduser(
  42. os.getenv(
  43. ENV_MGE_HOME,
  44. os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "megengine"),
  45. )
  46. )
  47. return megengine_home
  48. def _get_repo(
  49. git_host: str,
  50. repo_info: str,
  51. use_cache: bool = False,
  52. commit: str = None,
  53. protocol: str = DEFAULT_PROTOCOL,
  54. ) -> str:
  55. if protocol not in PROTOCOLS:
  56. raise InvalidProtocol(
  57. "Invalid protocol, the value should be one of {}.".format(
  58. ", ".join(PROTOCOLS.keys())
  59. )
  60. )
  61. cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
  62. with cd(cache_dir):
  63. fetcher = PROTOCOLS[protocol]
  64. repo_dir = fetcher.fetch(git_host, repo_info, use_cache, commit)
  65. return os.path.join(cache_dir, repo_dir)
  66. def _check_dependencies(module: types.ModuleType) -> None:
  67. if not hasattr(module, HUBDEPENDENCY):
  68. return
  69. dependencies = getattr(module, HUBDEPENDENCY)
  70. if not dependencies:
  71. return
  72. missing_deps = [m for m in dependencies if not check_module_exists(m)]
  73. if len(missing_deps):
  74. raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps)))
  75. def _init_hub(
  76. repo_info: str,
  77. git_host: str,
  78. use_cache: bool = True,
  79. commit: str = None,
  80. protocol: str = DEFAULT_PROTOCOL,
  81. ):
  82. """Imports hubmodule like python import
  83. :param repo_info:
  84. a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
  85. tag/branch. The default branch is ``master`` if not specified.
  86. Example: ``"brain_sdk/MegBrain[:hub]"``
  87. :param git_host:
  88. host address of git repo
  89. Example: github.com
  90. :param use_cache:
  91. whether to use locally cached code or completely re-fetch
  92. :param commit:
  93. commit id on github or gitlab
  94. :param protocol:
  95. which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
  96. The value should be one of HTTPS, SSH.
  97. :return:
  98. hubconf.py as a python module
  99. """
  100. cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
  101. os.makedirs(cache_dir, exist_ok=True)
  102. absolute_repo_dir = _get_repo(
  103. git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol
  104. )
  105. sys.path.insert(0, absolute_repo_dir)
  106. hubmodule = load_module(HUBCONF, os.path.join(absolute_repo_dir, HUBCONF))
  107. sys.path.remove(absolute_repo_dir)
  108. return hubmodule
  109. @functools.wraps(_init_hub)
  110. def import_module(*args, **kwargs):
  111. return _init_hub(*args, **kwargs)
  112. def list(
  113. repo_info: str,
  114. git_host: str = DEFAULT_GIT_HOST,
  115. use_cache: bool = True,
  116. commit: str = None,
  117. protocol: str = DEFAULT_PROTOCOL,
  118. ) -> List[str]:
  119. """Lists all entrypoints available in repo hubconf
  120. :param repo_info:
  121. a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
  122. tag/branch. The default branch is ``master`` if not specified.
  123. Example: ``"brain_sdk/MegBrain[:hub]"``
  124. :param git_host:
  125. host address of git repo
  126. Example: github.com
  127. :param use_cache:
  128. whether to use locally cached code or completely re-fetch
  129. :param commit:
  130. commit id on github or gitlab
  131. :param protocol:
  132. which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
  133. The value should be one of HTTPS, SSH.
  134. :return:
  135. all entrypoint names of the model
  136. """
  137. hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
  138. return [
  139. _
  140. for _ in dir(hubmodule)
  141. if not _.startswith("__") and callable(getattr(hubmodule, _))
  142. ]
  143. def load(
  144. repo_info: str,
  145. entry: str,
  146. *args,
  147. git_host: str = DEFAULT_GIT_HOST,
  148. use_cache: bool = True,
  149. commit: str = None,
  150. protocol: str = DEFAULT_PROTOCOL,
  151. **kwargs
  152. ) -> Any:
  153. """Loads model from github or gitlab repo, with pretrained weights.
  154. :param repo_info:
  155. a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
  156. tag/branch. The default branch is ``master`` if not specified.
  157. Example: ``"brain_sdk/MegBrain[:hub]"``
  158. :param entry:
  159. an entrypoint defined in hubconf
  160. :param git_host:
  161. host address of git repo
  162. Example: github.com
  163. :param use_cache:
  164. whether to use locally cached code or completely re-fetch
  165. :param commit:
  166. commit id on github or gitlab
  167. :param protocol:
  168. which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
  169. The value should be one of HTTPS, SSH.
  170. :return:
  171. a single model with corresponding pretrained weights.
  172. """
  173. hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
  174. if not hasattr(hubmodule, entry) or not callable(getattr(hubmodule, entry)):
  175. raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))
  176. _check_dependencies(hubmodule)
  177. module = getattr(hubmodule, entry)(*args, **kwargs)
  178. return module
  179. def help(
  180. repo_info: str,
  181. entry: str,
  182. git_host: str = DEFAULT_GIT_HOST,
  183. use_cache: bool = True,
  184. commit: str = None,
  185. protocol: str = DEFAULT_PROTOCOL,
  186. ) -> str:
  187. """This function returns docstring of entrypoint ``entry`` by following steps:
  188. 1. Pull the repo code specified by git and repo_info
  189. 2. Load the entry defined in repo's hubconf.py
  190. 3. Return docstring of function entry
  191. :param repo_info:
  192. a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
  193. tag/branch. The default branch is ``master`` if not specified.
  194. Example: ``"brain_sdk/MegBrain[:hub]"``
  195. :param entry:
  196. an entrypoint defined in hubconf.py
  197. :param git_host:
  198. host address of git repo
  199. Example: github.com
  200. :param use_cache:
  201. whether to use locally cached code or completely re-fetch
  202. :param commit:
  203. commit id on github or gitlab
  204. :param protocol:
  205. which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
  206. The value should be one of HTTPS, SSH.
  207. :return:
  208. docstring of entrypoint ``entry``
  209. """
  210. hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
  211. if not hasattr(hubmodule, entry) or not callable(getattr(hubmodule, entry)):
  212. raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))
  213. doc = getattr(hubmodule, entry).__doc__
  214. return doc
  215. def load_serialized_obj_from_url(url: str, model_dir=None) -> Any:
  216. """Loads MegEngine serialized object from the given URL.
  217. If the object is already present in ``model_dir``, it's deserialized and
  218. returned. If no ``model_dir`` is specified, it will be ``MGE_HOME/serialized``.
  219. :param url: url to serialized object
  220. :param model_dir: dir to cache target serialized file
  221. :return: loaded object
  222. """
  223. if model_dir is None:
  224. model_dir = os.path.join(_get_megengine_home(), "serialized")
  225. os.makedirs(model_dir, exist_ok=True)
  226. parts = urlparse(url)
  227. filename = os.path.basename(parts.path)
  228. # use hash as prefix to avoid filename conflict from different urls
  229. sha256 = hashlib.sha256()
  230. sha256.update(url.encode())
  231. digest = sha256.hexdigest()[:6]
  232. filename = digest + "_" + filename
  233. cached_file = os.path.join(model_dir, filename)
  234. logger.info(
  235. "load_serialized_obj_from_url: download to or using cached %s", cached_file
  236. )
  237. if not os.path.exists(cached_file):
  238. if is_distributed():
  239. logger.warning(
  240. "Downloading serialized object in DISTRIBUTED mode\n"
  241. " File may be downloaded multiple times. We recommend\n"
  242. " users to download in single process first."
  243. )
  244. download_from_url(url, cached_file, HTTP_READ_TIMEOUT)
  245. state_dict = _mge_load_serialized(cached_file)
  246. return state_dict
  247. class pretrained:
  248. r"""
  249. Decorator which helps to download pretrained weights from the given url.
  250. For example, we can decorate a resnet18 function as follows
  251. .. code-block::
  252. @hub.pretrained("https://url/to/pretrained_resnet18.pkl")
  253. def resnet18(**kwargs):
  254. return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
  255. When decorated function is called with ``pretrained=True``, MegEngine will automatically
  256. download and fill the returned model with pretrained weights.
  257. """
  258. def __init__(self, url):
  259. self.url = url
  260. def __call__(self, func):
  261. @functools.wraps(func)
  262. def pretrained_model_func(
  263. pretrained=False, **kwargs
  264. ): # pylint: disable=redefined-outer-name
  265. model = func(**kwargs)
  266. if pretrained:
  267. weights = load_serialized_obj_from_url(self.url)
  268. model.load_state_dict(weights)
  269. return model
  270. return pretrained_model_func
  271. __all__ = [
  272. "list",
  273. "load",
  274. "help",
  275. "load_serialized_obj_from_url",
  276. "pretrained",
  277. "import_module",
  278. ]

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台