You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import functools
  10. import hashlib
  11. import os
  12. import sys
  13. import types
  14. from typing import Any, List
  15. from urllib.parse import urlparse
  16. from megengine.utils.http_download import download_from_url
  17. from ..distributed import is_distributed
  18. from ..logger import get_logger
  19. from ..serialization import load as _mge_load_serialized
  20. from .const import (
  21. DEFAULT_CACHE_DIR,
  22. DEFAULT_GIT_HOST,
  23. DEFAULT_PROTOCOL,
  24. ENV_MGE_HOME,
  25. ENV_XDG_CACHE_HOME,
  26. HTTP_READ_TIMEOUT,
  27. HUBCONF,
  28. HUBDEPENDENCY,
  29. )
  30. from .exceptions import InvalidProtocol
  31. from .fetcher import GitHTTPSFetcher, GitSSHFetcher
  32. from .tools import cd, check_module_exists, load_module
  logger = get_logger(__name__)

  # Maps each accepted value of the ``protocol`` argument to the fetcher class
  # used to retrieve a repository over that protocol.
  PROTOCOLS = dict(HTTPS=GitHTTPSFetcher, SSH=GitSSHFetcher)
  def _get_megengine_home() -> str:
      """
      Return the MegEngine home directory.

      The MGE_HOME setting complies with the XDG Base Directory Specification.
      The result is resolved in this order:

      1. the environment variable named by ``ENV_MGE_HOME``, if set and non-empty;
      2. ``<XDG cache home>/megengine``, where the XDG cache home is the
         environment variable named by ``ENV_XDG_CACHE_HOME`` if set and
         non-empty, otherwise ``DEFAULT_CACHE_DIR``.

      A leading ``~`` in the chosen path is expanded to the user's home directory.

      :return: the MegEngine home directory path.
      """
      # The XDG spec says a variable that is set but empty must be treated as
      # unset. ``os.getenv(name, default)`` alone would return "" in that case
      # and produce a path relative to the current working directory.
      cache_home = os.getenv(ENV_XDG_CACHE_HOME) or DEFAULT_CACHE_DIR
      megengine_home = os.getenv(ENV_MGE_HOME) or os.path.join(cache_home, "megengine")
      return os.path.expanduser(megengine_home)
  def _get_repo(
      git_host: str,
      repo_info: str,
      use_cache: bool = False,
      commit: str = None,
      protocol: str = DEFAULT_PROTOCOL,
  ) -> str:
      """
      Fetch a repository into the hub cache directory and return its path.

      :param git_host: host address of the git repo.
      :param repo_info: ``"repo_owner/repo_name[:tag_name/:branch_name]"``.
      :param use_cache: forwarded to the fetcher (note: defaults to False here).
      :param commit: optional commit id, forwarded to the fetcher.
      :param protocol: a key of ``PROTOCOLS`` selecting the fetcher class.
      :return: ``<MegEngine home>/hub`` joined with the directory name returned
          by the fetcher.
      :raises InvalidProtocol: if ``protocol`` is not a key of ``PROTOCOLS``.
      """
      if protocol not in PROTOCOLS:
          choices = ", ".join(PROTOCOLS.keys())
          raise InvalidProtocol(
              "Invalid protocol, the value should be one of {}.".format(choices)
          )
      fetcher = PROTOCOLS[protocol]

      hub_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
      # The fetch runs with the hub dir as the working directory, and its result
      # is joined onto hub_dir below. Presumably the fetcher returns a path
      # relative to the cwd (TODO confirm against fetcher implementation).
      with cd(hub_dir):
          fetched_dir = fetcher.fetch(git_host, repo_info, use_cache, commit)
      return os.path.join(hub_dir, fetched_dir)
  def _check_dependencies(module: types.ModuleType) -> None:
      """
      Ensure every module listed in the hubconf dependency attribute is present.

      The attribute named by ``HUBDEPENDENCY`` is optional. When it is missing
      or empty, nothing is checked.

      :param module: the loaded hubconf module.
      :raises RuntimeError: listing every dependency for which
          ``check_module_exists`` returns a falsy value.
      """
      required = getattr(module, HUBDEPENDENCY, None)
      if not required:
          return
      missing = [name for name in required if not check_module_exists(name)]
      if missing:
          raise RuntimeError("Missing dependencies: {}".format(", ".join(missing)))
  def _init_hub(
      repo_info: str,
      git_host: str,
      use_cache: bool = True,
      commit: str = None,
      protocol: str = DEFAULT_PROTOCOL,
  ):
      """
      Imports hubmodule like python import.

      :param repo_info:
          a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
          tag/branch. The default branch is ``master`` if not specified.
          Example: ``"brain_sdk/MegBrain[:hub]"``
      :param git_host:
          host address of git repo.
          Example: github.com
      :param use_cache:
          whether to use locally cached code or completely re-fetch.
      :param commit:
          commit id on github or gitlab.
      :param protocol:
          which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
          The value should be one of HTTPS, SSH.
      :return:
          a python module.
      :raises InvalidProtocol: if ``protocol`` is not supported.
      """
      cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
      os.makedirs(cache_dir, exist_ok=True)
      absolute_repo_dir = _get_repo(
          git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol
      )
      # Put the repo at the front of sys.path while hubconf is loaded, so that
      # imports it performs can resolve modules inside the repo. Always undo
      # this, even when loading fails. Otherwise a broken hubconf would leave
      # the repo on sys.path for the rest of the process.
      sys.path.insert(0, absolute_repo_dir)
      try:
          hubmodule = load_module(HUBCONF, os.path.join(absolute_repo_dir, HUBCONF))
      finally:
          sys.path.remove(absolute_repo_dir)
      return hubmodule
  def import_module(*args, **kwargs):
      # Public alias of ``_init_hub``: all arguments are forwarded unchanged.
      return _init_hub(*args, **kwargs)


  # Copy ``_init_hub``'s metadata (``__doc__``, ``__name__``, ``__wrapped__``, ...)
  # onto the alias. ``inspect.signature`` then follows ``__wrapped__`` and
  # reports ``_init_hub``'s signature.
  import_module = functools.wraps(_init_hub)(import_module)
  def list(
      repo_info: str,
      git_host: str = DEFAULT_GIT_HOST,
      use_cache: bool = True,
      commit: str = None,
      protocol: str = DEFAULT_PROTOCOL,
  ) -> List[str]:
      """
      Lists all entrypoints available in repo hubconf.

      An entrypoint is any callable attribute of the loaded hubconf module whose
      name does not start with ``"__"``. This includes single-underscore names
      and callables that hubconf merely imports. Names come back in ``dir()``
      order, which is sorted.

      :param repo_info:
          a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
          tag/branch. The default branch is ``master`` if not specified.
          Example: ``"brain_sdk/MegBrain[:hub]"``
      :param git_host:
          host address of git repo.
          Example: github.com
      :param use_cache:
          whether to use locally cached code or completely re-fetch.
      :param commit:
          commit id on github or gitlab.
      :param protocol:
          which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
          The value should be one of HTTPS, SSH.
      :return:
          all entrypoint names of the model.
      """
      hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
      # Note: inside this module the name ``list`` refers to this function,
      # so build the result with a literal rather than ``list()``.
      entrypoints = []
      for name in dir(hubmodule):
          if name.startswith("__"):
              continue
          if callable(getattr(hubmodule, name)):
              entrypoints.append(name)
      return entrypoints
  def load(
      repo_info: str,
      entry: str,
      *args,
      git_host: str = DEFAULT_GIT_HOST,
      use_cache: bool = True,
      commit: str = None,
      protocol: str = DEFAULT_PROTOCOL,
      **kwargs
  ) -> Any:
      """
      Loads model from github or gitlab repo, with pretrained weights.

      :param repo_info:
          a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
          tag/branch. The default branch is ``master`` if not specified.
          Example: ``"brain_sdk/MegBrain[:hub]"``
      :param entry:
          an entrypoint defined in hubconf.
      :param args:
          extra positional arguments, forwarded to the entrypoint.
      :param git_host:
          host address of git repo.
          Example: github.com
      :param use_cache:
          whether to use locally cached code or completely re-fetch.
      :param commit:
          commit id on github or gitlab.
      :param protocol:
          which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
          The value should be one of HTTPS, SSH.
      :param kwargs:
          extra keyword arguments, forwarded to the entrypoint.
      :return:
          whatever the entrypoint returns, normally a single model with
          corresponding pretrained weights.
      :raises RuntimeError: if ``entry`` is not a callable attribute of hubconf,
          or if a dependency declared by hubconf is missing.
      """
      hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
      entrypoint = getattr(hubmodule, entry, None)
      # ``callable(None)`` is False, so a missing attribute is rejected here too.
      if not callable(entrypoint):
          raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))
      _check_dependencies(hubmodule)
      return entrypoint(*args, **kwargs)
  def help(
      repo_info: str,
      entry: str,
      git_host: str = DEFAULT_GIT_HOST,
      use_cache: bool = True,
      commit: str = None,
      protocol: str = DEFAULT_PROTOCOL,
  ) -> str:
      """
      Return the docstring of the hubconf entrypoint ``entry``.

      Steps:
      1. Pull the repo code specified by ``git_host`` and ``repo_info``.
      2. Load the entry defined in the repo's hubconf.py.
      3. Return the docstring of that entry.

      :param repo_info:
          a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
          tag/branch. The default branch is ``master`` if not specified.
          Example: ``"brain_sdk/MegBrain[:hub]"``
      :param entry:
          an entrypoint defined in hubconf.py
      :param git_host:
          host address of git repo.
          Example: github.com
      :param use_cache:
          whether to use locally cached code or completely re-fetch.
      :param commit:
          commit id on github or gitlab.
      :param protocol:
          which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
          The value should be one of HTTPS, SSH.
      :return:
          docstring of entrypoint ``entry`` (``None`` if it has none).
      :raises RuntimeError: if ``entry`` is not a callable attribute of hubconf.
      """
      hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
      entrypoint = getattr(hubmodule, entry, None)
      # A missing attribute yields None, which is not callable either.
      if not callable(entrypoint):
          raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))
      return entrypoint.__doc__
  def load_serialized_obj_from_url(url: str, model_dir=None) -> Any:
      """
      Load a MegEngine serialized object from ``url``, caching it on disk.

      The object is stored as ``<model_dir>/<hash>_<basename>``, where ``<hash>``
      is the first 6 hex digits of the SHA-256 of the full URL and
      ``<basename>`` is the last component of the URL path. If that file already
      exists it is deserialized directly. Otherwise it is downloaded first.

      :param url: url to serialized object.
      :param model_dir: dir to cache target serialized file. Defaults to
          ``<MegEngine home>/serialized``. Created if it does not exist.
      :return: loaded object.
      """
      if model_dir is None:
          model_dir = os.path.join(_get_megengine_home(), "serialized")
      os.makedirs(model_dir, exist_ok=True)

      basename = os.path.basename(urlparse(url).path)
      # Prefix with a short hash of the whole URL so that different URLs
      # sharing the same basename do not overwrite each other's cache entry.
      url_digest = hashlib.sha256(url.encode()).hexdigest()[:6]
      cached_file = os.path.join(model_dir, "{}_{}".format(url_digest, basename))

      logger.info(
          "load_serialized_obj_from_url: download to or using cached %s", cached_file
      )
      if not os.path.exists(cached_file):
          if is_distributed():
              logger.warning(
                  "Downloading serialized object in DISTRIBUTED mode\n"
                  " File may be downloaded multiple times. We recommend\n"
                  " users to download in single process first."
              )
          download_from_url(url, cached_file, HTTP_READ_TIMEOUT)
      # NOTE(review): this deserializes data fetched from an arbitrary URL. If
      # the MegEngine loader is pickle-based (TODO confirm), only trusted URLs
      # are safe to load.
      return _mge_load_serialized(cached_file)
  class pretrained:
      r"""
      Decorator that optionally fills a model with weights downloaded from a URL.

      Example, decorating a ``resnet18`` factory:

      .. code-block::

          @hub.pretrained("https://url/to/pretrained_resnet18.pkl")
          def resnet18(**kwargs):
              return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)

      The decorated factory accepts a ``pretrained`` flag in addition to its own
      keyword arguments. When the flag is truthy, the weights at the decorator's
      ``url`` are downloaded (or read from cache) and passed to the returned
      model's ``load_state_dict``.

      Note: the wrapper forwards only keyword arguments to the factory. Its first
      positional argument is bound to ``pretrained``.
      """

      def __init__(self, url):
          self.url = url

      def __call__(self, func):
          decorator = self

          @functools.wraps(func)
          def pretrained_model_func(pretrained=False, **kwargs):  # pylint: disable=redefined-outer-name
              model = func(**kwargs)
              if not pretrained:
                  return model
              # ``url`` is read at call time, not at decoration time.
              state = load_serialized_obj_from_url(decorator.url)
              model.load_state_dict(state)
              return model

          return pretrained_model_func
  # Public API of this module. Note that ``list`` and ``help`` shadow the
  # builtins of the same name within this module's namespace.
  __all__ = ["list", "load", "help", "load_serialized_obj_from_url", "pretrained", "import_module"]

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。