You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

fetcher.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import hashlib
  10. import os
  11. import re
  12. import shutil
  13. import subprocess
  14. from tempfile import NamedTemporaryFile
  15. from typing import Tuple
  16. from zipfile import ZipFile
  17. import requests
  18. from tqdm import tqdm
  19. from megengine.utils.http_download import (
  20. CHUNK_SIZE,
  21. HTTP_CONNECTION_TIMEOUT,
  22. HTTPDownloadError,
  23. )
  24. from ..distributed import is_distributed, synchronized
  25. from ..logger import get_logger
  26. from .const import DEFAULT_BRANCH_NAME, HTTP_READ_TIMEOUT
  27. from .exceptions import GitCheckoutError, GitPullError, InvalidGitHost, InvalidRepo
  28. from .tools import cd
  29. logger = get_logger(__name__)
  30. HTTP_TIMEOUT = (HTTP_CONNECTION_TIMEOUT, HTTP_READ_TIMEOUT)
  31. pattern = re.compile(
  32. r"^(?:[a-z0-9]" # First character of the domain
  33. r"(?:[a-z0-9-_]{0,61}[a-z0-9])?\.)" # Sub domain + hostname
  34. r"+[a-z0-9][a-z0-9-_]{0,61}" # First 61 characters of the gTLD
  35. r"[a-z]$" # Last character of the gTLD
  36. )
  37. class RepoFetcherBase:
  38. @classmethod
  39. def fetch(
  40. cls,
  41. git_host: str,
  42. repo_info: str,
  43. use_cache: bool = False,
  44. commit: str = None,
  45. silent: bool = True,
  46. ) -> str:
  47. raise NotImplementedError()
  48. @classmethod
  49. def _parse_repo_info(cls, repo_info: str) -> Tuple[str, str, str]:
  50. try:
  51. branch_info = DEFAULT_BRANCH_NAME
  52. if ":" in repo_info:
  53. prefix_info, branch_info = repo_info.split(":")
  54. else:
  55. prefix_info = repo_info
  56. repo_owner, repo_name = prefix_info.split("/")
  57. return repo_owner, repo_name, branch_info
  58. except ValueError:
  59. raise InvalidRepo("repo_info: '{}' is invalid.".format(repo_info))
  60. @classmethod
  61. def _check_git_host(cls, git_host):
  62. return cls._is_valid_domain(git_host) or cls._is_valid_host(git_host)
  63. @classmethod
  64. def _is_valid_domain(cls, s):
  65. try:
  66. return pattern.match(s.encode("idna").decode("ascii"))
  67. except UnicodeError:
  68. return False
  69. @classmethod
  70. def _is_valid_host(cls, s):
  71. nums = s.split(".")
  72. if len(nums) != 4 or any(not _.isdigit() for _ in nums):
  73. return False
  74. return all(0 <= int(_) < 256 for _ in nums)
  75. @classmethod
  76. def _gen_repo_dir(cls, repo_dir: str) -> str:
  77. return hashlib.sha1(repo_dir.encode()).hexdigest()[:16]
  78. class GitSSHFetcher(RepoFetcherBase):
  79. @classmethod
  80. @synchronized
  81. def fetch(
  82. cls,
  83. git_host: str,
  84. repo_info: str,
  85. use_cache: bool = False,
  86. commit: str = None,
  87. silent: bool = True,
  88. ) -> str:
  89. """
  90. Fetches git repo by SSH protocol
  91. :param git_host:
  92. host address of git repo.
  93. Example: github.com
  94. :param repo_info:
  95. a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
  96. tag/branch. The default branch is ``master`` if not specified.
  97. Example: ``"brain_sdk/MegBrain[:hub]"``
  98. :param use_cache:
  99. whether to use locally fetched code or completely re-fetch.
  100. :param commit:
  101. commit id on github or gitlab.
  102. :param silent:
  103. whether to accept the stdout and stderr of the subprocess with PIPE, instead of
  104. displaying on the screen.
  105. :return:
  106. directory where the repo code is stored.
  107. """
  108. if not cls._check_git_host(git_host):
  109. raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host))
  110. repo_owner, repo_name, branch_info = cls._parse_repo_info(repo_info)
  111. normalized_branch_info = branch_info.replace("/", "_")
  112. repo_dir_raw = "{}_{}_{}".format(
  113. repo_owner, repo_name, normalized_branch_info
  114. ) + ("_{}".format(commit) if commit else "")
  115. repo_dir = cls._gen_repo_dir(repo_dir_raw)
  116. git_url = "git@{}:{}/{}.git".format(git_host, repo_owner, repo_name)
  117. if use_cache and os.path.exists(repo_dir): # use cache
  118. logger.debug("Cache Found in %s", repo_dir)
  119. return repo_dir
  120. if is_distributed():
  121. logger.warning(
  122. "When using `hub.load` or `hub.list` to fetch git repositories\n"
  123. " in DISTRIBUTED mode for the first time, processes are synchronized to\n"
  124. " ensure that target repository is ready to use for each process.\n"
  125. " Users are expected to see this warning no more than ONCE, otherwise\n"
  126. " (very little chance) you may need to remove corrupt cache\n"
  127. " `%s` and fetch again.",
  128. repo_dir,
  129. )
  130. shutil.rmtree(repo_dir, ignore_errors=True) # ignore and clear cache
  131. logger.debug(
  132. "Git Clone from Repo:%s Branch: %s to %s",
  133. git_url,
  134. normalized_branch_info,
  135. repo_dir,
  136. )
  137. kwargs = (
  138. {"stderr": subprocess.PIPE, "stdout": subprocess.PIPE} if silent else {}
  139. )
  140. if commit is None:
  141. # shallow clone repo by branch/tag
  142. p = subprocess.Popen(
  143. [
  144. "git",
  145. "clone",
  146. "-b",
  147. normalized_branch_info,
  148. git_url,
  149. repo_dir,
  150. "--depth=1",
  151. ],
  152. **kwargs,
  153. )
  154. cls._check_clone_pipe(p)
  155. else:
  156. # clone repo and checkout to commit_id
  157. p = subprocess.Popen(["git", "clone", git_url, repo_dir], **kwargs)
  158. cls._check_clone_pipe(p)
  159. with cd(repo_dir):
  160. logger.debug("git checkout to %s", commit)
  161. p = subprocess.Popen(["git", "checkout", commit], **kwargs)
  162. _, err = p.communicate()
  163. if p.returncode:
  164. shutil.rmtree(repo_dir, ignore_errors=True)
  165. raise GitCheckoutError(
  166. "Git checkout error, please check the commit id.\n"
  167. + err.decode()
  168. )
  169. with cd(repo_dir):
  170. shutil.rmtree(".git")
  171. return repo_dir
  172. @classmethod
  173. def _check_clone_pipe(cls, p):
  174. _, err = p.communicate()
  175. if p.returncode:
  176. raise GitPullError(
  177. "Repo pull error, please check repo info.\n" + err.decode()
  178. )
  179. class GitHTTPSFetcher(RepoFetcherBase):
  180. @classmethod
  181. @synchronized
  182. def fetch(
  183. cls,
  184. git_host: str,
  185. repo_info: str,
  186. use_cache: bool = False,
  187. commit: str = None,
  188. silent: bool = True,
  189. ) -> str:
  190. """
  191. Fetches git repo by HTTPS protocol.
  192. :param git_host:
  193. host address of git repo.
  194. Example: github.com
  195. :param repo_info:
  196. a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
  197. tag/branch. The default branch is ``master`` if not specified.
  198. Example: ``"brain_sdk/MegBrain[:hub]"``
  199. :param use_cache:
  200. whether to use locally cached code or completely re-fetch.
  201. :param commit:
  202. commit id on github or gitlab.
  203. :param silent:
  204. whether to accept the stdout and stderr of the subprocess with PIPE, instead of
  205. displaying on the screen.
  206. :return:
  207. directory where the repo code is stored.
  208. """
  209. if not cls._check_git_host(git_host):
  210. raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host))
  211. repo_owner, repo_name, branch_info = cls._parse_repo_info(repo_info)
  212. normalized_branch_info = branch_info.replace("/", "_")
  213. repo_dir_raw = "{}_{}_{}".format(
  214. repo_owner, repo_name, normalized_branch_info
  215. ) + ("_{}".format(commit) if commit else "")
  216. repo_dir = cls._gen_repo_dir(repo_dir_raw)
  217. archive_url = cls._git_archive_link(
  218. git_host, repo_owner, repo_name, branch_info, commit
  219. )
  220. if use_cache and os.path.exists(repo_dir): # use cache
  221. logger.debug("Cache Found in %s", repo_dir)
  222. return repo_dir
  223. if is_distributed():
  224. logger.warning(
  225. "When using `hub.load` or `hub.list` to fetch git repositories "
  226. "in DISTRIBUTED mode for the first time, processes are synchronized to "
  227. "ensure that target repository is ready to use for each process.\n"
  228. "Users are expected to see this warning no more than ONCE, otherwise"
  229. "(very little chance) you may need to remove corrupt hub cache %s and fetch again."
  230. )
  231. shutil.rmtree(repo_dir, ignore_errors=True) # ignore and clear cache
  232. logger.debug("Downloading from %s to %s", archive_url, repo_dir)
  233. cls._download_zip_and_extract(archive_url, repo_dir)
  234. return repo_dir
  235. @classmethod
  236. def _download_zip_and_extract(cls, url, target_dir):
  237. resp = requests.get(url, timeout=HTTP_TIMEOUT, stream=True)
  238. if resp.status_code != 200:
  239. raise HTTPDownloadError(
  240. "An error occured when downloading from {}".format(url)
  241. )
  242. total_size = int(resp.headers.get("Content-Length", 0))
  243. _bar = tqdm(total=total_size, unit="iB", unit_scale=True)
  244. with NamedTemporaryFile("w+b") as f:
  245. for chunk in resp.iter_content(CHUNK_SIZE):
  246. if not chunk:
  247. break
  248. _bar.update(len(chunk))
  249. f.write(chunk)
  250. _bar.close()
  251. f.seek(0)
  252. with ZipFile(f) as temp_zip_f:
  253. zip_dir_name = temp_zip_f.namelist()[0].split("/")[0]
  254. temp_zip_f.extractall(".")
  255. shutil.move(zip_dir_name, target_dir)
  256. @classmethod
  257. def _git_archive_link(cls, git_host, repo_owner, repo_name, branch_info, commit):
  258. archive_link = "https://{}/{}/{}/archive/{}.zip".format(
  259. git_host, repo_owner, repo_name, commit or branch_info
  260. )
  261. return archive_link

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台