You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

fetcher.py 11 kB

(line-number gutter: 1–305)
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import hashlib
  10. import os
  11. import re
  12. import shutil
  13. import subprocess
  14. from tempfile import NamedTemporaryFile
  15. from typing import Tuple
  16. from zipfile import ZipFile
  17. import requests
  18. from tqdm import tqdm
  19. from megengine import __version__
  20. from megengine.utils.http_download import (
  21. CHUNK_SIZE,
  22. HTTP_CONNECTION_TIMEOUT,
  23. HTTPDownloadError,
  24. )
  25. from ..distributed import is_distributed, synchronized
  26. from ..logger import get_logger
  27. from .const import DEFAULT_BRANCH_NAME, HTTP_READ_TIMEOUT
  28. from .exceptions import GitCheckoutError, GitPullError, InvalidGitHost, InvalidRepo
  29. from .tools import cd
  logger = get_logger(__name__)

  # (connect timeout, read timeout) pair, passed as ``timeout=`` to requests.get.
  HTTP_TIMEOUT = (HTTP_CONNECTION_TIMEOUT, HTTP_READ_TIMEOUT)

  # Syntax check for (IDNA-encoded) domain names, used by
  # RepoFetcherBase._is_valid_domain.
  # * re.IGNORECASE: host names are case-insensitive, so "GitHub.com" must be
  #   accepted just like "github.com".
  # * "\Z" instead of "$": "$" also matches right before a trailing newline, which
  #   would let "github.com\n" pass validation and end up inside the git URL.
  pattern = re.compile(
      r"^(?:[a-z0-9]"  # First character of the domain
      r"(?:[a-z0-9-_]{0,61}[a-z0-9])?\.)"  # Sub domain + hostname
      r"+[a-z0-9][a-z0-9-_]{0,61}"  # Leading characters of the gTLD
      r"[a-z]\Z",  # Last character of the gTLD
      re.IGNORECASE,
  )
  class RepoFetcherBase:
      """Shared helpers for the concrete repo fetchers.

      Subclasses implement :meth:`fetch`. This base class parses ``repo_info``
      strings, validates ``git_host`` values and derives the local cache
      directory name.
      """

      @classmethod
      def fetch(
          cls,
          git_host: str,
          repo_info: str,
          use_cache: bool = False,
          commit: str = None,
          silent: bool = True,
      ) -> str:
          """Fetch a repo and return the directory holding its code.

          Not implemented here; concrete fetchers must override it.
          """
          raise NotImplementedError()

      @classmethod
      def _parse_repo_info(cls, repo_info: str) -> Tuple[str, str, str]:
          """Split ``"owner/name[:branch]"`` into ``(owner, name, branch)``.

          ``branch`` falls back to ``DEFAULT_BRANCH_NAME`` when no ``:`` is
          present. The branch may itself contain ``/`` (e.g. ``"feature/x"``),
          because only the part before ``:`` is split on ``/``.

          :raises InvalidRepo: if there is more than one ``:``, the part before
              it is not exactly ``owner/name``, or any component is empty.
          """
          parts = repo_info.split(":")
          if len(parts) == 1:
              prefix_info, branch_info = parts[0], DEFAULT_BRANCH_NAME
          elif len(parts) == 2:
              prefix_info, branch_info = parts
          else:
              raise InvalidRepo("repo_info: '{}' is invalid.".format(repo_info))

          owner_and_name = prefix_info.split("/")
          # Empty components ("/name", "owner/", "owner/name:") would otherwise
          # produce unusable git URLs / branch names further down.
          if len(owner_and_name) != 2 or not all(owner_and_name) or not branch_info:
              raise InvalidRepo("repo_info: '{}' is invalid.".format(repo_info))
          repo_owner, repo_name = owner_and_name
          return repo_owner, repo_name, branch_info

      @classmethod
      def _check_git_host(cls, git_host):
          """Return True if ``git_host`` is a valid domain name or IPv4 address."""
          return bool(cls._is_valid_domain(git_host) or cls._is_valid_host(git_host))

      @classmethod
      def _is_valid_domain(cls, s):
          """Return True if ``s`` (possibly internationalized) matches the
          module-level domain ``pattern`` after IDNA encoding."""
          try:
              ascii_name = s.encode("idna").decode("ascii")
          except UnicodeError:
              return False
          return pattern.match(ascii_name) is not None

      @classmethod
      def _is_valid_host(cls, s):
          """Return True if ``s`` is a dotted-quad IPv4 address, e.g. ``"10.0.0.1"``."""
          octets = s.split(".")
          if len(octets) != 4:
              return False
          # str.isdigit() also accepts non-ASCII digits such as "²", which int()
          # then rejects with ValueError, so only plain ASCII 0-9 are allowed.
          if not all(octet and all(c in "0123456789" for c in octet) for octet in octets):
              return False
          return all(int(octet) < 256 for octet in octets)

      @classmethod
      def _gen_repo_dir(cls, repo_dir: str) -> str:
          """Return the first 16 hex digits of the SHA-1 of ``repo_dir``, used as a
          short, filesystem-safe cache directory key."""
          return hashlib.sha1(repo_dir.encode()).hexdigest()[:16]
  class GitSSHFetcher(RepoFetcherBase):
      """Fetches git repositories by cloning them over SSH (``git@host:owner/name.git``)."""

      @classmethod
      @synchronized
      def fetch(
          cls,
          git_host: str,
          repo_info: str,
          use_cache: bool = False,
          commit: str = None,
          silent: bool = True,
      ) -> str:
          """
          Fetches git repo by SSH protocol.

          :param git_host:
              host address of git repo.
              Example: github.com
          :param repo_info:
              a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
              tag/branch. Falls back to ``DEFAULT_BRANCH_NAME`` (from ``.const``) if not specified.
              Example: ``"brain_sdk/MegBrain[:hub]"``
          :param use_cache:
              whether to use locally fetched code or completely re-fetch.
          :param commit:
              commit id on github or gitlab. When given, the full repo is cloned and this
              commit is checked out; otherwise a shallow clone of the branch/tag is made.
          :param silent:
              whether to accept the stdout and stderr of the subprocess with PIPE, instead of
              displaying on the screen.
          :return:
              directory (relative to the current working directory) where the repo code is
              stored. Its ``.git`` directory is removed after fetching.
          :raises InvalidGitHost: if ``git_host`` is neither a valid domain nor an IPv4 address.
          :raises GitPullError: if ``git clone`` fails.
          :raises GitCheckoutError: if checking out ``commit`` fails.
          """
          if not cls._check_git_host(git_host):
              raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host))

          repo_owner, repo_name, branch_info = cls._parse_repo_info(repo_info)
          # "/" cannot appear inside a single path component, so it is replaced only
          # for building the cache key; git must still receive the real branch name.
          normalized_branch_info = branch_info.replace("/", "_")
          repo_dir_raw = "{}_{}_{}".format(
              repo_owner, repo_name, normalized_branch_info
          ) + ("_{}".format(commit) if commit else "")
          # Prefix with the MegEngine version so caches of different versions never collide.
          repo_dir = (
              "_".join(__version__.split(".")) + "_" + cls._gen_repo_dir(repo_dir_raw)
          )
          git_url = "git@{}:{}/{}.git".format(git_host, repo_owner, repo_name)

          if use_cache and os.path.exists(repo_dir):  # use cache
              logger.debug("Cache Found in %s", repo_dir)
              return repo_dir

          if is_distributed():
              logger.warning(
                  "When using `hub.load` or `hub.list` to fetch git repositories\n"
                  " in DISTRIBUTED mode for the first time, processes are synchronized to\n"
                  " ensure that target repository is ready to use for each process.\n"
                  " Users are expected to see this warning no more than ONCE, otherwise\n"
                  " (very little chance) you may need to remove corrupt cache\n"
                  " `%s` and fetch again.",
                  repo_dir,
              )

          shutil.rmtree(repo_dir, ignore_errors=True)  # ignore and clear cache

          logger.debug(
              "Git Clone from Repo:%s Branch: %s to %s",
              git_url,
              branch_info,
              repo_dir,
          )

          kwargs = (
              {"stderr": subprocess.PIPE, "stdout": subprocess.PIPE} if silent else {}
          )
          if commit is None:
              # shallow clone repo by branch/tag (the real name, not the normalized one)
              p = subprocess.Popen(
                  [
                      "git",
                      "clone",
                      "-b",
                      branch_info,
                      git_url,
                      repo_dir,
                      "--depth=1",
                  ],
                  **kwargs,
              )
              cls._check_clone_pipe(p)
          else:
              # clone repo and checkout to commit_id
              p = subprocess.Popen(["git", "clone", git_url, repo_dir], **kwargs)
              cls._check_clone_pipe(p)

              with cd(repo_dir):
                  logger.debug("git checkout to %s", commit)
                  p = subprocess.Popen(["git", "checkout", commit], **kwargs)
                  _, err = p.communicate()
              # Clean up only after leaving the ``cd`` context: while the working
              # directory is repo_dir, the relative path ``repo_dir`` would not
              # point at the clone and rmtree would silently do nothing.
              if p.returncode:
                  shutil.rmtree(repo_dir, ignore_errors=True)
                  raise GitCheckoutError(
                      "Git checkout error, please check the commit id.\n"
                      + cls._decode_stderr(err)
                  )

          # Only the working tree is needed; drop the git metadata.
          shutil.rmtree(os.path.join(repo_dir, ".git"))
          return repo_dir

      @classmethod
      def _check_clone_pipe(cls, p):
          """Wait for the ``git clone`` process ``p`` and raise GitPullError if it failed."""
          _, err = p.communicate()
          if p.returncode:
              raise GitPullError(
                  "Repo pull error, please check repo info.\n" + cls._decode_stderr(err)
              )

      @staticmethod
      def _decode_stderr(err):
          """Turn captured stderr into text for error messages.

          ``err`` is None when stderr was not piped (``silent=False``), and git
          output is not guaranteed to be UTF-8, so decoding must not raise.
          """
          return err.decode(errors="replace") if err else ""
  class GitHTTPSFetcher(RepoFetcherBase):
      """Fetches git repositories by downloading a zip archive of a ref over HTTPS."""

      @classmethod
      @synchronized
      def fetch(
          cls,
          git_host: str,
          repo_info: str,
          use_cache: bool = False,
          commit: str = None,
          silent: bool = True,
      ) -> str:
          """
          Fetches git repo by HTTPS protocol.

          :param git_host:
              host address of git repo.
              Example: github.com
          :param repo_info:
              a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
              tag/branch. Falls back to ``DEFAULT_BRANCH_NAME`` (from ``.const``) if not specified.
              Example: ``"brain_sdk/MegBrain[:hub]"``
          :param use_cache:
              whether to use locally cached code or completely re-fetch.
          :param commit:
              commit id on github or gitlab. When given, the archive of this commit is
              downloaded instead of the branch/tag archive.
          :param silent:
              accepted for interface compatibility with GitSSHFetcher; not used here.
          :return:
              directory (relative to the current working directory) where the repo code is stored.
          :raises InvalidGitHost: if ``git_host`` is neither a valid domain nor an IPv4 address.
          :raises HTTPDownloadError: if the archive download does not return HTTP 200.
          """
          if not cls._check_git_host(git_host):
              raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host))

          repo_owner, repo_name, branch_info = cls._parse_repo_info(repo_info)
          # "/" is replaced only for the cache key; the archive URL uses the real name.
          normalized_branch_info = branch_info.replace("/", "_")
          repo_dir_raw = "{}_{}_{}".format(
              repo_owner, repo_name, normalized_branch_info
          ) + ("_{}".format(commit) if commit else "")
          # Prefix with the MegEngine version so caches of different versions never collide.
          repo_dir = (
              "_".join(__version__.split(".")) + "_" + cls._gen_repo_dir(repo_dir_raw)
          )
          archive_url = cls._git_archive_link(
              git_host, repo_owner, repo_name, branch_info, commit
          )

          if use_cache and os.path.exists(repo_dir):  # use cache
              logger.debug("Cache Found in %s", repo_dir)
              return repo_dir

          if is_distributed():
              # The message has a %s placeholder, so repo_dir must be passed as an
              # argument; otherwise a literal "%s" ends up in the log.
              logger.warning(
                  "When using `hub.load` or `hub.list` to fetch git repositories "
                  "in DISTRIBUTED mode for the first time, processes are synchronized to "
                  "ensure that target repository is ready to use for each process.\n"
                  "Users are expected to see this warning no more than ONCE, otherwise "
                  "(very little chance) you may need to remove corrupt hub cache %s and fetch again.",
                  repo_dir,
              )

          shutil.rmtree(repo_dir, ignore_errors=True)  # ignore and clear cache

          logger.debug("Downloading from %s to %s", archive_url, repo_dir)
          cls._download_zip_and_extract(archive_url, repo_dir)
          return repo_dir

      @classmethod
      def _download_zip_and_extract(cls, url, target_dir):
          """Download the zip archive at ``url`` and move its top-level directory to
          ``target_dir``.

          :raises HTTPDownloadError: if the server does not answer with HTTP 200.
          """
          # Local import keeps the module-level import list unchanged.
          from tempfile import TemporaryDirectory

          resp = requests.get(url, timeout=HTTP_TIMEOUT, stream=True)
          try:
              if resp.status_code != 200:
                  raise HTTPDownloadError(
                      "An error occured when downloading from {}".format(url)
                  )
              total_size = int(resp.headers.get("Content-Length", 0))
              with NamedTemporaryFile("w+b") as f:
                  # Context manager closes the progress bar even if the download fails.
                  with tqdm(total=total_size, unit="iB", unit_scale=True) as _bar:
                      for chunk in resp.iter_content(CHUNK_SIZE):
                          if not chunk:
                              break
                          _bar.update(len(chunk))
                          f.write(chunk)
                  f.seek(0)
                  # Extract into a fresh temporary directory next to target_dir (same
                  # filesystem, so the move is a rename) rather than into the CWD,
                  # where a stale directory of the same name could be merged into.
                  parent_dir = os.path.dirname(os.path.abspath(target_dir))
                  with ZipFile(f) as temp_zip_f, TemporaryDirectory(
                      dir=parent_dir
                  ) as extract_dir:
                      # The archive is assumed to hold a single top-level directory.
                      zip_dir_name = temp_zip_f.namelist()[0].split("/")[0]
                      temp_zip_f.extractall(extract_dir)
                      shutil.move(os.path.join(extract_dir, zip_dir_name), target_dir)
          finally:
              # stream=True keeps the connection open until the response is closed.
              resp.close()

      @classmethod
      def _git_archive_link(cls, git_host, repo_owner, repo_name, branch_info, commit):
          """Build ``https://<host>/<owner>/<name>/archive/<ref>.zip``, where ``<ref>``
          is ``commit`` if given, otherwise ``branch_info``."""
          archive_link = "https://{}/{}/{}/archive/{}.zip".format(
              git_host, repo_owner, repo_name, commit or branch_info
          )
          return archive_link

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。