You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

fetcher.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import hashlib
  10. import os
  11. import re
  12. import shutil
  13. import subprocess
  14. from tempfile import NamedTemporaryFile
  15. from typing import Tuple
  16. from zipfile import ZipFile
  17. import requests
  18. from tqdm import tqdm
  19. from megengine import __version__
  20. from megengine.utils.http_download import (
  21. CHUNK_SIZE,
  22. HTTP_CONNECTION_TIMEOUT,
  23. HTTPDownloadError,
  24. )
  25. from ..distributed import is_distributed, synchronized
  26. from ..logger import get_logger
  27. from .const import DEFAULT_BRANCH_NAME, HTTP_READ_TIMEOUT
  28. from .exceptions import GitCheckoutError, GitPullError, InvalidGitHost, InvalidRepo
  29. from .tools import cd
  30. logger = get_logger(__name__)
  31. HTTP_TIMEOUT = (HTTP_CONNECTION_TIMEOUT, HTTP_READ_TIMEOUT)
  32. pattern = re.compile(
  33. r"^(?:[a-z0-9]" # First character of the domain
  34. r"(?:[a-z0-9-_]{0,61}[a-z0-9])?\.)" # Sub domain + hostname
  35. r"+[a-z0-9][a-z0-9-_]{0,61}" # First 61 characters of the gTLD
  36. r"[a-z]$" # Last character of the gTLD
  37. )
  38. class RepoFetcherBase:
  39. @classmethod
  40. def fetch(
  41. cls,
  42. git_host: str,
  43. repo_info: str,
  44. use_cache: bool = False,
  45. commit: str = None,
  46. silent: bool = True,
  47. ) -> str:
  48. raise NotImplementedError()
  49. @classmethod
  50. def _parse_repo_info(cls, repo_info: str) -> Tuple[str, str, str]:
  51. try:
  52. branch_info = DEFAULT_BRANCH_NAME
  53. if ":" in repo_info:
  54. prefix_info, branch_info = repo_info.split(":")
  55. else:
  56. prefix_info = repo_info
  57. repo_owner, repo_name = prefix_info.split("/")
  58. return repo_owner, repo_name, branch_info
  59. except ValueError:
  60. raise InvalidRepo("repo_info: '{}' is invalid.".format(repo_info))
  61. @classmethod
  62. def _check_git_host(cls, git_host):
  63. return cls._is_valid_domain(git_host) or cls._is_valid_host(git_host)
  64. @classmethod
  65. def _is_valid_domain(cls, s):
  66. try:
  67. return pattern.match(s.encode("idna").decode("ascii"))
  68. except UnicodeError:
  69. return False
  70. @classmethod
  71. def _is_valid_host(cls, s):
  72. nums = s.split(".")
  73. if len(nums) != 4 or any(not _.isdigit() for _ in nums):
  74. return False
  75. return all(0 <= int(_) < 256 for _ in nums)
  76. @classmethod
  77. def _gen_repo_dir(cls, repo_dir: str) -> str:
  78. return hashlib.sha1(repo_dir.encode()).hexdigest()[:16]
  79. class GitSSHFetcher(RepoFetcherBase):
  80. @classmethod
  81. @synchronized
  82. def fetch(
  83. cls,
  84. git_host: str,
  85. repo_info: str,
  86. use_cache: bool = False,
  87. commit: str = None,
  88. silent: bool = True,
  89. ) -> str:
  90. """Fetches git repo by SSH protocol
  91. Args:
  92. git_host: host address of git repo. Eg: github.com
  93. repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
  94. tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
  95. use_cache: whether to use locally fetched code or completely re-fetch.
  96. commit: commit id on github or gitlab.
  97. silent: whether to accept the stdout and stderr of the subprocess with PIPE, instead of
  98. displaying on the screen.
  99. Returns:
  100. directory where the repo code is stored.
  101. """
  102. if not cls._check_git_host(git_host):
  103. raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host))
  104. repo_owner, repo_name, branch_info = cls._parse_repo_info(repo_info)
  105. normalized_branch_info = branch_info.replace("/", "_")
  106. repo_dir_raw = "{}_{}_{}".format(
  107. repo_owner, repo_name, normalized_branch_info
  108. ) + ("_{}".format(commit) if commit else "")
  109. repo_dir = (
  110. "_".join(__version__.split(".")) + "_" + cls._gen_repo_dir(repo_dir_raw)
  111. )
  112. git_url = "git@{}:{}/{}.git".format(git_host, repo_owner, repo_name)
  113. if use_cache and os.path.exists(repo_dir): # use cache
  114. logger.debug("Cache Found in %s", repo_dir)
  115. return repo_dir
  116. if is_distributed():
  117. logger.warning(
  118. "When using `hub.load` or `hub.list` to fetch git repositories\n"
  119. " in DISTRIBUTED mode for the first time, processes are synchronized to\n"
  120. " ensure that target repository is ready to use for each process.\n"
  121. " Users are expected to see this warning no more than ONCE, otherwise\n"
  122. " (very little chance) you may need to remove corrupt cache\n"
  123. " `%s` and fetch again.",
  124. repo_dir,
  125. )
  126. shutil.rmtree(repo_dir, ignore_errors=True) # ignore and clear cache
  127. logger.debug(
  128. "Git Clone from Repo:%s Branch: %s to %s",
  129. git_url,
  130. normalized_branch_info,
  131. repo_dir,
  132. )
  133. kwargs = (
  134. {"stderr": subprocess.PIPE, "stdout": subprocess.PIPE} if silent else {}
  135. )
  136. if commit is None:
  137. # shallow clone repo by branch/tag
  138. p = subprocess.Popen(
  139. [
  140. "git",
  141. "clone",
  142. "-b",
  143. normalized_branch_info,
  144. git_url,
  145. repo_dir,
  146. "--depth=1",
  147. ],
  148. **kwargs,
  149. )
  150. cls._check_clone_pipe(p)
  151. else:
  152. # clone repo and checkout to commit_id
  153. p = subprocess.Popen(["git", "clone", git_url, repo_dir], **kwargs)
  154. cls._check_clone_pipe(p)
  155. with cd(repo_dir):
  156. logger.debug("git checkout to %s", commit)
  157. p = subprocess.Popen(["git", "checkout", commit], **kwargs)
  158. _, err = p.communicate()
  159. if p.returncode:
  160. shutil.rmtree(repo_dir, ignore_errors=True)
  161. raise GitCheckoutError(
  162. "Git checkout error, please check the commit id.\n"
  163. + err.decode()
  164. )
  165. with cd(repo_dir):
  166. shutil.rmtree(".git")
  167. return repo_dir
  168. @classmethod
  169. def _check_clone_pipe(cls, p):
  170. _, err = p.communicate()
  171. if p.returncode:
  172. raise GitPullError(
  173. "Repo pull error, please check repo info.\n" + err.decode()
  174. )
  175. class GitHTTPSFetcher(RepoFetcherBase):
  176. @classmethod
  177. @synchronized
  178. def fetch(
  179. cls,
  180. git_host: str,
  181. repo_info: str,
  182. use_cache: bool = False,
  183. commit: str = None,
  184. silent: bool = True,
  185. ) -> str:
  186. """Fetches git repo by HTTPS protocol.
  187. Args:
  188. git_host: host address of git repo. Eg: github.com
  189. repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
  190. tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
  191. use_cache: whether to use locally cached code or completely re-fetch.
  192. commit: commit id on github or gitlab.
  193. silent: whether to accept the stdout and stderr of the subprocess with PIPE, instead of
  194. displaying on the screen.
  195. Returns:
  196. directory where the repo code is stored.
  197. """
  198. if not cls._check_git_host(git_host):
  199. raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host))
  200. repo_owner, repo_name, branch_info = cls._parse_repo_info(repo_info)
  201. normalized_branch_info = branch_info.replace("/", "_")
  202. repo_dir_raw = "{}_{}_{}".format(
  203. repo_owner, repo_name, normalized_branch_info
  204. ) + ("_{}".format(commit) if commit else "")
  205. repo_dir = (
  206. "_".join(__version__.split(".")) + "_" + cls._gen_repo_dir(repo_dir_raw)
  207. )
  208. archive_url = cls._git_archive_link(
  209. git_host, repo_owner, repo_name, branch_info, commit
  210. )
  211. if use_cache and os.path.exists(repo_dir): # use cache
  212. logger.debug("Cache Found in %s", repo_dir)
  213. return repo_dir
  214. if is_distributed():
  215. logger.warning(
  216. "When using `hub.load` or `hub.list` to fetch git repositories "
  217. "in DISTRIBUTED mode for the first time, processes are synchronized to "
  218. "ensure that target repository is ready to use for each process.\n"
  219. "Users are expected to see this warning no more than ONCE, otherwise"
  220. "(very little chance) you may need to remove corrupt hub cache %s and fetch again."
  221. )
  222. shutil.rmtree(repo_dir, ignore_errors=True) # ignore and clear cache
  223. logger.debug("Downloading from %s to %s", archive_url, repo_dir)
  224. cls._download_zip_and_extract(archive_url, repo_dir)
  225. return repo_dir
  226. @classmethod
  227. def _download_zip_and_extract(cls, url, target_dir):
  228. resp = requests.get(url, timeout=HTTP_TIMEOUT, stream=True)
  229. if resp.status_code != 200:
  230. raise HTTPDownloadError(
  231. "An error occured when downloading from {}".format(url)
  232. )
  233. total_size = int(resp.headers.get("Content-Length", 0))
  234. _bar = tqdm(total=total_size, unit="iB", unit_scale=True)
  235. with NamedTemporaryFile("w+b") as f:
  236. for chunk in resp.iter_content(CHUNK_SIZE):
  237. if not chunk:
  238. break
  239. _bar.update(len(chunk))
  240. f.write(chunk)
  241. _bar.close()
  242. f.seek(0)
  243. with ZipFile(f) as temp_zip_f:
  244. zip_dir_name = temp_zip_f.namelist()[0].split("/")[0]
  245. temp_zip_f.extractall(".")
  246. shutil.move(zip_dir_name, target_dir)
  247. @classmethod
  248. def _git_archive_link(cls, git_host, repo_owner, repo_name, branch_info, commit):
  249. archive_link = "https://{}/{}/{}/archive/{}.zip".format(
  250. git_host, repo_owner, repo_name, commit or branch_info
  251. )
  252. return archive_link