You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

fetcher.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. # -*- coding: utf-8 -*-
  2. import hashlib
  3. import os
  4. import re
  5. import shutil
  6. import subprocess
  7. from tempfile import NamedTemporaryFile
  8. from typing import Tuple
  9. from zipfile import ZipFile
  10. import requests
  11. from tqdm import tqdm
  12. from megengine import __version__
  13. from megengine.utils.http_download import (
  14. CHUNK_SIZE,
  15. HTTP_CONNECTION_TIMEOUT,
  16. HTTPDownloadError,
  17. )
  18. from ..distributed import is_distributed, synchronized
  19. from ..logger import get_logger
  20. from .const import DEFAULT_BRANCH_NAME, HTTP_READ_TIMEOUT
  21. from .exceptions import GitCheckoutError, GitPullError, InvalidGitHost, InvalidRepo
  22. from .tools import cd
  # Logger shared by every fetcher defined in this module.
  logger = get_logger(__name__)

  # (connect, read) timeout pair handed to ``requests`` for archive downloads.
  HTTP_TIMEOUT = (HTTP_CONNECTION_TIMEOUT, HTTP_READ_TIMEOUT)

  # Domain-name validator: one or more "label." groups followed by a final
  # label (the TLD) that must end in a letter. The character classes list only
  # lowercase ASCII letters, digits, "-" and "_".
  pattern = re.compile(
      "".join(
          (
              r"^(?:[a-z0-9]",  # every label starts with a letter or digit
              r"(?:[a-z0-9-_]{0,61}[a-z0-9])?\.)",  # optional label body, then a dot
              r"+[a-z0-9][a-z0-9-_]{0,61}",  # repeat labels; then the TLD head
              r"[a-z]$",  # the TLD ends with a letter
          )
      )
  )
  class RepoFetcherBase:
      """Common helpers for fetchers that materialize a git repository locally.

      Subclasses implement :meth:`fetch`; this base class parses ``repo_info``,
      validates ``git_host`` and derives the cache-directory name.
      """

      @classmethod
      def fetch(
          cls,
          git_host: str,
          repo_info: str,
          use_cache: bool = False,
          commit: str = None,
          silent: bool = True,
      ) -> str:
          """Fetch the repository and return the directory holding its code.

          Not implemented here; concrete fetchers override it.
          """
          raise NotImplementedError()

      @classmethod
      def _parse_repo_info(cls, repo_info: str) -> Tuple[str, str, str]:
          """Split ``"owner/name[:branch]"`` into ``(owner, name, branch)``.

          The branch defaults to ``DEFAULT_BRANCH_NAME`` when no ``:branch``
          suffix is present.

          Raises:
              InvalidRepo: if the prefix does not contain exactly one ``/``,
                  there is more than one ``:``, or any component is empty.
          """
          try:
              if ":" in repo_info:
                  prefix_info, branch_info = repo_info.split(":")
              else:
                  prefix_info, branch_info = repo_info, DEFAULT_BRANCH_NAME
              repo_owner, repo_name = prefix_info.split("/")
          except ValueError:
              # Wrong number of ":" or "/" separators makes tuple unpacking fail.
              raise InvalidRepo("repo_info: '{}' is invalid.".format(repo_info))
          # Empty parts ("/name", "owner/", "owner/name:") would otherwise only
          # surface later as an obscure git/HTTP failure.
          if not (repo_owner and repo_name and branch_info):
              raise InvalidRepo("repo_info: '{}' is invalid.".format(repo_info))
          return repo_owner, repo_name, branch_info

      @classmethod
      def _check_git_host(cls, git_host):
          """Return a truthy value if ``git_host`` is a domain name or an IPv4 address."""
          return cls._is_valid_domain(git_host) or cls._is_valid_host(git_host)

      @classmethod
      def _is_valid_domain(cls, s):
          """Return a truthy match if ``s`` is a domain name accepted by ``pattern``.

          ``s`` is IDNA-encoded so internationalized names are checked in their
          ASCII form, then lowercased: host names are case-insensitive while
          ``pattern`` only lists lowercase letters. Returns ``False`` when the
          name cannot be IDNA-encoded.
          """
          try:
              ascii_name = s.encode("idna").decode("ascii")
          except UnicodeError:
              return False
          return pattern.match(ascii_name.lower())

      @classmethod
      def _is_valid_host(cls, s):
          """Return True if ``s`` is a dotted-quad IPv4 address such as ``"10.0.0.1"``."""
          octets = s.split(".")
          if len(octets) != 4:
              return False
          # Require plain ASCII digits: str.isdigit() also accepts characters
          # such as "²" (which int() rejects with ValueError) and non-ASCII
          # decimal digits such as "١".
          if not all(re.fullmatch(r"[0-9]+", octet) for octet in octets):
              return False
          return all(0 <= int(octet) < 256 for octet in octets)

      @classmethod
      def _gen_repo_dir(cls, repo_dir: str) -> str:
          """Return the first 16 hex digits of the SHA-1 of ``repo_dir``."""
          return hashlib.sha1(repo_dir.encode()).hexdigest()[:16]
  class GitSSHFetcher(RepoFetcherBase):
      """Fetch a git repository by cloning ``git@<host>:<owner>/<name>.git``."""

      @classmethod
      @synchronized
      def fetch(
          cls,
          git_host: str,
          repo_info: str,
          use_cache: bool = False,
          commit: str = None,
          silent: bool = True,
      ) -> str:
          """Fetches git repo by SSH protocol

          Args:
              git_host: host address of git repo. Eg: github.com
              repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"``
                  with an optional tag/branch. The default branch is ``DEFAULT_BRANCH_NAME``
                  if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
              use_cache: whether to use locally fetched code or completely re-fetch.
              commit: commit id on github or gitlab.
              silent: whether to accept the stdout and stderr of the subprocess with PIPE,
                  instead of displaying on the screen.

          Returns:
              directory where the repo code is stored.

          Raises:
              InvalidGitHost: if ``git_host`` is neither a domain name nor an IPv4 address.
              InvalidRepo: if ``repo_info`` is malformed.
              GitPullError: if ``git clone`` fails.
              GitCheckoutError: if ``git checkout <commit>`` fails.
          """
          if not cls._check_git_host(git_host):
              raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host))
          repo_owner, repo_name, branch_info = cls._parse_repo_info(repo_info)
          # The normalized name only contributes to the cache key (kept as-is so
          # existing caches stay valid); git must get the real branch/tag name,
          # e.g. "release/v1" rather than "release_v1".
          normalized_branch_info = branch_info.replace("/", "_")
          repo_dir_raw = "{}_{}_{}".format(
              repo_owner, repo_name, normalized_branch_info
          ) + ("_{}".format(commit) if commit else "")
          repo_dir = (
              "_".join(__version__.split(".")) + "_" + cls._gen_repo_dir(repo_dir_raw)
          )
          git_url = "git@{}:{}/{}.git".format(git_host, repo_owner, repo_name)

          if use_cache and os.path.exists(repo_dir):  # use cache
              logger.debug("Cache Found in %s", repo_dir)
              return repo_dir

          if is_distributed():
              logger.warning(
                  "When using `hub.load` or `hub.list` to fetch git repositories\n"
                  " in DISTRIBUTED mode for the first time, processes are synchronized to\n"
                  " ensure that target repository is ready to use for each process.\n"
                  " Users are expected to see this warning no more than ONCE, otherwise\n"
                  " (very little chance) you may need to remove corrupt cache\n"
                  " `%s` and fetch again.",
                  repo_dir,
              )

          shutil.rmtree(repo_dir, ignore_errors=True)  # ignore and clear cache

          logger.debug(
              "Git Clone from Repo:%s Branch: %s to %s",
              git_url,
              branch_info,
              repo_dir,
          )

          kwargs = (
              {"stderr": subprocess.PIPE, "stdout": subprocess.PIPE} if silent else {}
          )
          if commit is None:
              # shallow clone repo by branch/tag
              p = subprocess.Popen(
                  ["git", "clone", "-b", branch_info, git_url, repo_dir, "--depth=1"],
                  **kwargs,
              )
              cls._check_clone_pipe(p)
          else:
              # full clone so an arbitrary commit id is available, then check it out
              p = subprocess.Popen(["git", "clone", git_url, repo_dir], **kwargs)
              cls._check_clone_pipe(p)
              with cd(repo_dir):
                  logger.debug("git checkout to %s", commit)
                  p = subprocess.Popen(["git", "checkout", commit], **kwargs)
                  _, err = p.communicate()
              # Checked after leaving repo_dir: repo_dir is a relative path, so
              # removing it from inside would target repo_dir/repo_dir (assumes
              # ``cd`` restores the previous working directory on exit).
              if p.returncode:
                  shutil.rmtree(repo_dir, ignore_errors=True)
                  raise GitCheckoutError(
                      "Git checkout error, please check the commit id.\n"
                      + cls._decode_stderr(err)
                  )

          # Only the working tree is kept; the git metadata is dropped.
          shutil.rmtree(os.path.join(repo_dir, ".git"))
          return repo_dir

      @classmethod
      def _check_clone_pipe(cls, p):
          """Wait for the ``git clone`` process ``p`` and raise GitPullError if it failed."""
          _, err = p.communicate()
          if p.returncode:
              raise GitPullError(
                  "Repo pull error, please check repo info.\n" + cls._decode_stderr(err)
              )

      @staticmethod
      def _decode_stderr(err):
          """Decode captured stderr bytes for an error message.

          ``err`` is None when stderr was not piped (``silent=False``); undecodable
          bytes are replaced so they cannot mask the original git error.
          """
          return err.decode(errors="replace") if err else ""
  class GitHTTPSFetcher(RepoFetcherBase):
      """Fetch a git repository by downloading and unpacking its zip archive over HTTPS."""

      @classmethod
      @synchronized
      def fetch(
          cls,
          git_host: str,
          repo_info: str,
          use_cache: bool = False,
          commit: str = None,
          silent: bool = True,
      ) -> str:
          """Fetches git repo by HTTPS protocol.

          Args:
              git_host: host address of git repo. Eg: github.com
              repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"``
                  with an optional tag/branch. The default branch is ``DEFAULT_BRANCH_NAME``
                  if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
              use_cache: whether to use locally cached code or completely re-fetch.
              commit: commit id on github or gitlab.
              silent: whether to accept the stdout and stderr of the subprocess with PIPE,
                  instead of displaying on the screen. (Unused by this fetcher.)

          Returns:
              directory where the repo code is stored.

          Raises:
              InvalidGitHost: if ``git_host`` is neither a domain name nor an IPv4 address.
              InvalidRepo: if ``repo_info`` is malformed.
              HTTPDownloadError: if the archive download does not return HTTP 200.
          """
          if not cls._check_git_host(git_host):
              raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host))
          repo_owner, repo_name, branch_info = cls._parse_repo_info(repo_info)
          # The normalized name only feeds the cache key; the archive URL below
          # uses the real branch/tag name.
          normalized_branch_info = branch_info.replace("/", "_")
          repo_dir_raw = "{}_{}_{}".format(
              repo_owner, repo_name, normalized_branch_info
          ) + ("_{}".format(commit) if commit else "")
          repo_dir = (
              "_".join(__version__.split(".")) + "_" + cls._gen_repo_dir(repo_dir_raw)
          )
          archive_url = cls._git_archive_link(
              git_host, repo_owner, repo_name, branch_info, commit
          )

          if use_cache and os.path.exists(repo_dir):  # use cache
              logger.debug("Cache Found in %s", repo_dir)
              return repo_dir

          if is_distributed():
              logger.warning(
                  "When using `hub.load` or `hub.list` to fetch git repositories "
                  "in DISTRIBUTED mode for the first time, processes are synchronized to "
                  "ensure that target repository is ready to use for each process.\n"
                  "Users are expected to see this warning no more than ONCE, otherwise "
                  "(very little chance) you may need to remove corrupt hub cache %s and fetch again.",
                  repo_dir,
              )

          shutil.rmtree(repo_dir, ignore_errors=True)  # ignore and clear cache
          logger.debug("Downloading from %s to %s", archive_url, repo_dir)
          cls._download_zip_and_extract(archive_url, repo_dir)
          return repo_dir

      @classmethod
      def _download_zip_and_extract(cls, url, target_dir):
          """Download the zip at ``url`` and move its top-level folder to ``target_dir``.

          The folder name is taken from the first entry of the archive.

          Raises:
              HTTPDownloadError: if the server does not answer with status 200.
          """
          from tempfile import TemporaryDirectory

          resp = requests.get(url, timeout=HTTP_TIMEOUT, stream=True)
          try:
              if resp.status_code != 200:
                  raise HTTPDownloadError(
                      "An error occured when downloading from {}".format(url)
                  )
              total_size = int(resp.headers.get("Content-Length", 0))
              with NamedTemporaryFile("w+b") as f:
                  with tqdm(total=total_size, unit="iB", unit_scale=True) as bar:
                      for chunk in resp.iter_content(CHUNK_SIZE):
                          # Skip empty keep-alive chunks instead of stopping early,
                          # which could truncate the archive.
                          if chunk:
                              bar.update(len(chunk))
                              f.write(chunk)
                  f.seek(0)
                  with ZipFile(f) as zip_file:
                      zip_dir_name = zip_file.namelist()[0].split("/")[0]
                      # Extract into a fresh directory under the CWD (same
                      # filesystem as target_dir) so a stale folder with the same
                      # name in the CWD is never merged into or moved instead.
                      with TemporaryDirectory(dir=".") as extract_dir:
                          zip_file.extractall(extract_dir)
                          shutil.move(
                              os.path.join(extract_dir, zip_dir_name), target_dir
                          )
          finally:
              # stream=True keeps the connection open until the response is closed.
              resp.close()

      @classmethod
      def _git_archive_link(cls, git_host, repo_owner, repo_name, branch_info, commit):
          """Build ``https://<host>/<owner>/<name>/archive/<ref>.zip``.

          ``<ref>`` is ``commit`` when given, otherwise ``branch_info``.
          """
          archive_link = "https://{}/{}/{}/archive/{}.zip".format(
              git_host, repo_owner, repo_name, commit or branch_info
          )
          return archive_link
  241. def _git_archive_link(cls, git_host, repo_owner, repo_name, branch_info, commit):
  242. archive_link = "https://{}/{}/{}/archive/{}.zip".format(
  243. git_host, repo_owner, repo_name, commit or branch_info
  244. )
  245. return archive_link