You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

hub.py 22 kB

resolve construct prediction bug Former-commit-id: b0d35d9aea14dc7b9382642dc4aebbb7b33227af [formerly 28b7582d253ac7370edb4ee77c41a78ec5efe21b] [formerly 632f8af13ad3d37eb5e721730ccef4c795052ff0 [formerly 9eb35832567152e9d8b616951cba2f9d472e3fd0]] [formerly ab6f5a8a9c6b9255cdcd65805974015a7daf0693 [formerly 2720e70c78084e35396f51bb28c2090679f1a53b] [formerly c1f615a3c58a4916b810c6a566c16c12bc507908 [formerly 21840b378759723d0ff5a11d59d3e2daade1e679]]] [formerly 2184b60ad634dbc847bfaeb3413ae99b824bad0a [formerly 2dae6c4ed4a1995848d151ca54fddd69969145a1] [formerly 6a577865b8a1f92a963b3d168e06a63dcb0161b3 [formerly ee28aaed84e9f06385a453ed06a7314fc603ca91]] [formerly 72078d88c8363efb9a91ddd5963d234b8ab0fd10 [formerly ab1fb9d5125467f2bad37c01b395460e15830910] [formerly 476600a47c7aa6c88895eda67185ee6a70d1fd5d [formerly f5a6205bb56bdfcde3584db8107b25af51ce1feb]]]] [formerly e5c2c3deefcac7dea9094877181d7910eb5286a5 [formerly 8585dc41a9c94b9a7084484b60e3cc07b9cef7fe] [formerly 9cc7fe2088bcc893022abf4b6ba137fdbdda4211 [formerly 1d2104316f6dc308482b2d8c1c0210a240ef4f4a]] [formerly b51e614a11931bf92fcbc497ce6509339a4a9771 [formerly f9891a191d390c09ecd274931f6fbab9f09fe06c] [formerly 048aa2f114fe950c50e3d9833dbe4967c65f25b1 [formerly d4de64574be6221a8287ce4c4f72e7b39f361258]]] [formerly ff66f55a512c8e42de6d48182b79b9c8e42bfae1 [formerly 0b691b9a7f54cef7d12d14b5e5859959503838c8] [formerly e64e8dd25314730134bf6de25d3d3c0a9d5c8d63 [formerly cc45939e01d88b193138b72384ba519e4df731e7]] [formerly cdad10712ac93799628bcf24365402a347dbb16d [formerly 2789e20e79f5b9726ad7c1a002a31c5d2253f63b] [formerly 25924f293c4e9d69e44c77a08a05c46d3f66bb27 [formerly 32997accab262d1a8960e38792035422d14efd13]]]]] [formerly f43b4310405ef100174e295d57504653eb26f04f [formerly 95815d02ca055259aa579dc3822f1e77ddca8a1c] [formerly fe9bd45d441e3608622bdfb90d3748bb2fd17bbd [formerly 6daf0aa73e40a4027dd5ae04cf16522974b71b36]] [formerly 61ab30c9a3ab7d8fb4d7366dd09937813e34a7fb [formerly 
a13c6e23b463cb54a80938f981b811f33326f68e] [formerly 86fa5919eec79d126e750a176f67f1554eac41fa [formerly 1e49e1a30348ac38fe7e31a59a50006aabfac43c]]] [formerly 69c5bc967a3bd09af946c28c8680a510614897fd [formerly ab82915cbd9157ad0dc204c0247107bda650742a] [formerly f8057c3b14e022eb44a0abb79f50196e093eaa35 [formerly 5232f345780a58c7544bd6a4d5c31b930cd0a36e]] [formerly 671c54e9520203b5ed14053770c4a1ed4b621959 [formerly 6454a28f2685a4176ead96d7ee68e15afaaadfa8] [formerly 3db6ff66b90f6abf68a8c78ae929d548c1f6b2ad [formerly aa8c7fe127cf82a244718da550b121052855ad27]]]] [formerly 86b2ec6b84d16ae7d5c51eace4f24ea472ea11a2 [formerly f35c344efebe54ad170cd6de9bc7f4126e53f510] [formerly d5616f66cd90f797ad80a85b38fe9234f86540fe [formerly 98c9dca7daf1efc38766b3fde4854f06538d7efa]] [formerly a7dcc62bc54b16698b3b7e06aa9a6735afd63b08 [formerly 4ef4fa0c98b2f210cbe7de6f2a00ba5c63c1b2ba] [formerly 55f670b9ae053698ddba84658940ddea9ac795e6 [formerly 1cd4421e2e02bfb8597f338583ad32c0550c02af]]] [formerly d7a5bab8320a31eb1dedac6a7806e5c59683053f [formerly c77c5b48df83e1aed383fb28f0812fc87ecbf973] [formerly 01ffd33e2ff1a6fa5cffd6dd82289111682cd9a1 [formerly aea728ceb67c39709f9a3921547a95e09c1dcff7]] [formerly 16afb18e352c8da18fff35384430bbc427cc4802 [formerly 4768b156f37cfba1a657944fe8d4314af3532685] [formerly 3c1298c626a945d362a74f922c5660b11e8ad491 [formerly 1e61cf09743c7ec285301e748dbbec3caf42d89c]]]]]] Former-commit-id: 28b09a56cb8b16b07fc356e5f3133ad2bb9d5859 [formerly 5241dbb36c08ebc2ecfeac6450b1cb3b52f624ea] [formerly d43909d97927eee8f9f09fcc8b4e5c92cddb3a6a [formerly 43b0cca7f5e0372770f2b2a530bc921d8ac9ec8d]] [formerly 8d8f384c8e46748e59790e0421d07bd8a942043d [formerly bcf58203c66d0bc6c5392a855d1b5d617dcef9df] [formerly ca56bff2d0a2b498941b59b240f1cf4ac420462c [formerly 7a8750ffc99115c6cb37e63ce9fd0c3958d235a8]]] [formerly 2ce9fa87aeb4fe174de68dc4d79501b057b477b1 [formerly e7b1b542f5b9fbf69df7912c4393d264713c24f4] [formerly 62a7edf94a8193828f465bd432ffc15edd480ea1 [formerly 
26ca5f220dcc3745ee271e5712d72836f2255222]] [formerly 73d10253b93b35f692e13089ddb148d408cf14fe [formerly af463cecb0333e14aabc502126f7c6d22973e214] [formerly 961314f474a483302d2a5ce0a12b51afc386cb6d [formerly 6e7141e5e2e9f72d2102892b4affce0de0944cb0]]]] [formerly c3f8938ba5ba29f2ad8de1cfc486a7f323ecea69 [formerly 7f6674829240eab485a34d67eb41f61ca744b26f] [formerly 3e5ef2c136584b7f0e769f5bdf43ab994fb12c21 [formerly 266ddb4ccc8948b124ef81cee46efd1804bddd2e]] [formerly 4a8a5437b39ea3daf303fa79dbae22c294ff0b40 [formerly cbcb0f87775b068d2fccd5c8632cf05ab0a019a7] [formerly c212b8217fd92fa1b7899ab5211e0262f3a9c927 [formerly 5f3d3d01c87574acb44b077dd7439ac53b98a0f1]]] [formerly 09764ba6cd565eae5cb426f705ec72ee255e4b49 [formerly 4db991be5f3f4d27092a19876eb75c5e51f60bb0] [formerly f79c6ec15df35536aeba92957c37306643e3a517 [formerly 7f8eb54d4752bbefbc208251563af8e77a1792c0]] [formerly 771b00b188ceaa0b9569636bae5f28643999ce47 [formerly f1dcba565fc0943c24d468b496b48df7f815420f] [formerly 83c42510ad4af56394a30b78f2b4eeb693b2df6b [formerly 3c1298c626a945d362a74f922c5660b11e8ad491]]]]] Former-commit-id: f5cdcca4f3e91cc75d3af0de6bd17cce1dfaf021 [formerly dc57f947a2385967e8847ab784a433184f9620a3] [formerly 55cd6eb9a35d44b00d0d55a30c9686dffafdae57 [formerly e92d6c09230f08e9760fc17f45f8389e919b82a3]] [formerly ba80ed43d23d7ff4c349e534fb9036b84064bbf3 [formerly 0a2f65401a73b676569fb809d0e4e9dd28803952] [formerly 73c0c2ebb36b0a14c038f327507a133d082fef6f [formerly bf80cf285eec74221e6add1ff41e5f5c3b3ab808]]] [formerly a34f21d93329c3bee2d5e97fcbe820a0eab1d5df [formerly ae311cb3e77af5bcd03976e2a108ac1eba772fc4] [formerly ae0e3ed079c601e0096c648911ae8246eef220fe [formerly c9030d0303acdbd031b19797c30e98573cf32837]] [formerly f8f4f6a8ec75e269e0db55255d3a01dcafc79c23 [formerly 343dd65df69b51f0f404f200c3af4cbc52320b25] [formerly 1c97e6a7baa6ff68d466723b754b2887cc834be9 [formerly c69d281aca97ee3bde4328d04d92163eab376623]]]] Former-commit-id: b93de572408cfd327f0ca40b8d62530945e5c5ba [formerly 
6fd32d0759069e91a1183acb86096ccbab5f88e9] [formerly ff29e71cb0b43473b0be6345b1747e6ec2df856b [formerly ee483b56f98eb5e187b7b262278d121c34ed86ab]] [formerly 2f944ec28b50590b4b4eb1e7aa07518f78234810 [formerly 15c46e806cd9f35403a758aa6f46f3d746b0d0c5] [formerly dca0c18b8b693227cdef4539165e8e2d6a47c340 [formerly f185a8658cf2893c7fa8b0393d52489593d32d37]]] Former-commit-id: 76ed7f184f13ac60e2d458d60e9a8befe3cace6f [formerly c8adbe1dea7c40540fc74be2867a780b10a5788f] [formerly 392fc3e54b90f8d1201d2986ffbde77529b26560 [formerly 5bd396738e6152ff5b4b28bbda4cb495bcda1459]] Former-commit-id: 11579edbe104d2a7361847a5dd71fc06dc4f369c [formerly 6970795314200993d1d027617ceddd7a7c6421a3] Former-commit-id: 8fca90698b3ffff3c183e4754cabc6b39fe9b16f
4 years ago

  1. import errno
  2. import hashlib
  3. import os
  4. import re
  5. import shutil
  6. import sys
  7. import tempfile
  8. # import torch
  9. import warnings
  10. import zipfile
  11. from urllib.request import urlopen, Request
  12. from urllib.parse import urlparse # noqa: F401
  13. try:
  14. from tqdm.auto import tqdm # automatically select proper tqdm submodule if available
  15. except ImportError:
  16. try:
  17. from tqdm import tqdm
  18. except ImportError:
  19. # fake tqdm if it's not installed
  20. class tqdm(object): # type: ignore
  21. def __init__(self, total=None, disable=False,
  22. unit=None, unit_scale=None, unit_divisor=None):
  23. self.total = total
  24. self.disable = disable
  25. self.n = 0
  26. # ignore unit, unit_scale, unit_divisor; they're just for real tqdm
  27. def update(self, n):
  28. if self.disable:
  29. return
  30. self.n += n
  31. if self.total is None:
  32. sys.stderr.write("\r{0:.1f} bytes".format(self.n))
  33. else:
  34. sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
  35. sys.stderr.flush()
  36. def __enter__(self):
  37. return self
  38. def __exit__(self, exc_type, exc_val, exc_tb):
  39. if self.disable:
  40. return
  41. sys.stderr.write('\n')
  42. # # matches bfd8deac from resnet18-bfd8deac.pth
  43. # HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
  44. #
  45. # MASTER_BRANCH = 'master'
  46. # ENV_TORCH_HOME = 'TORCH_HOME'
  47. # ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
  48. # DEFAULT_CACHE_DIR = '~/.cache'
  49. # VAR_DEPENDENCY = 'dependencies'
  50. # MODULE_HUBCONF = 'hubconf.py'
  51. # READ_DATA_CHUNK = 8192
  52. # _hub_dir = None
  53. #
  54. #
  55. # # Copied from tools/shared/module_loader to be included in torch package
  56. # def import_module(name, path):
  57. # import importlib.util
  58. # from importlib.abc import Loader
  59. # spec = importlib.util.spec_from_file_location(name, path)
  60. # module = importlib.util.module_from_spec(spec)
  61. # assert isinstance(spec.loader, Loader)
  62. # spec.loader.exec_module(module)
  63. # return module
  64. #
  65. #
  66. # def _remove_if_exists(path):
  67. # if os.path.exists(path):
  68. # if os.path.isfile(path):
  69. # os.remove(path)
  70. # else:
  71. # shutil.rmtree(path)
  72. #
  73. #
  74. # def _git_archive_link(repo_owner, repo_name, branch):
  75. # return 'https://github.com/{}/{}/archive/{}.zip'.format(repo_owner, repo_name, branch)
  76. #
  77. #
  78. # def _load_attr_from_module(module, func_name):
  79. # # Check if callable is defined in the module
  80. # if func_name not in dir(module):
  81. # return None
  82. # return getattr(module, func_name)
  83. #
  84. #
  85. # def _get_torch_home():
  86. # torch_home = os.path.expanduser(
  87. # os.getenv(ENV_TORCH_HOME,
  88. # os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
  89. # DEFAULT_CACHE_DIR), 'torch')))
  90. # return torch_home
  91. #
  92. #
  93. # def _parse_repo_info(github):
  94. # branch = MASTER_BRANCH
  95. # if ':' in github:
  96. # repo_info, branch = github.split(':')
  97. # else:
  98. # repo_info = github
  99. # repo_owner, repo_name = repo_info.split('/')
  100. # return repo_owner, repo_name, branch
  101. #
  102. #
  103. # def _get_cache_or_reload(github, force_reload, verbose=True):
  104. # # Setup hub_dir to save downloaded files
  105. # hub_dir = get_dir()
  106. # if not os.path.exists(hub_dir):
  107. # os.makedirs(hub_dir)
  108. # # Parse github repo information
  109. # repo_owner, repo_name, branch = _parse_repo_info(github)
  110. # # Github allows branch name with slash '/',
  111. # # this causes confusion with path on both Linux and Windows.
  112. # # Backslash is not allowed in Github branch names, so there is no need
  113. # # to worry about it.
  114. # normalized_br = branch.replace('/', '_')
  115. # # Github renames folder repo-v1.x.x to repo-1.x.x
  116. # # We don't know the repo name before downloading the zip file
  117. # # and inspect name from it.
  118. # # To check if cached repo exists, we need to normalize folder names.
  119. # repo_dir = os.path.join(hub_dir, '_'.join([repo_owner, repo_name, normalized_br]))
  120. #
  121. # use_cache = (not force_reload) and os.path.exists(repo_dir)
  122. #
  123. # if use_cache:
  124. # if verbose:
  125. # sys.stderr.write('Using cache found in {}\n'.format(repo_dir))
  126. # else:
  127. # cached_file = os.path.join(hub_dir, normalized_br + '.zip')
  128. # _remove_if_exists(cached_file)
  129. #
  130. # url = _git_archive_link(repo_owner, repo_name, branch)
  131. # sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, cached_file))
  132. # download_url_to_file(url, cached_file, progress=False)
  133. #
  134. # with zipfile.ZipFile(cached_file) as cached_zipfile:
  135. # extraced_repo_name = cached_zipfile.infolist()[0].filename
  136. # extracted_repo = os.path.join(hub_dir, extraced_repo_name)
  137. # _remove_if_exists(extracted_repo)
  138. # # Unzip the code and rename the base folder
  139. # cached_zipfile.extractall(hub_dir)
  140. #
  141. # _remove_if_exists(cached_file)
  142. # _remove_if_exists(repo_dir)
  143. # shutil.move(extracted_repo, repo_dir) # rename the repo
  144. #
  145. # return repo_dir
  146. #
  147. #
  148. # def _check_module_exists(name):
  149. # if sys.version_info >= (3, 4):
  150. # import importlib.util
  151. # return importlib.util.find_spec(name) is not None
  152. # elif sys.version_info >= (3, 3):
  153. # # Special case for python3.3
  154. # import importlib.find_loader
  155. # return importlib.find_loader(name) is not None
  156. # else:
  157. # # NB: Python2.7 imp.find_module() doesn't respect PEP 302,
  158. # # it cannot find a package installed as .egg(zip) file.
  159. # # Here we use workaround from:
  160. # # https://stackoverflow.com/questions/28962344/imp-find-module-which-supports-zipped-eggs?lq=1
  161. # # Also imp doesn't handle hierarchical module names (names contains dots).
  162. # try:
  163. # # 1. Try imp.find_module(), which searches sys.path, but does
  164. # # not respect PEP 302 import hooks.
  165. # import imp
  166. # result = imp.find_module(name)
  167. # if result:
  168. # return True
  169. # except ImportError:
  170. # pass
  171. # path = sys.path
  172. # for item in path:
  173. # # 2. Scan path for import hooks. sys.path_importer_cache maps
  174. # # path items to optional "importer" objects, that implement
  175. # # find_module() etc. Note that path must be a subset of
  176. # # sys.path for this to work.
  177. # importer = sys.path_importer_cache.get(item)
  178. # if importer:
  179. # try:
  180. # result = importer.find_module(name, [item])
  181. # if result:
  182. # return True
  183. # except ImportError:
  184. # pass
  185. # return False
  186. #
  187. # def _check_dependencies(m):
  188. # dependencies = _load_attr_from_module(m, VAR_DEPENDENCY)
  189. #
  190. # if dependencies is not None:
  191. # missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)]
  192. # if len(missing_deps):
  193. # raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing_deps)))
  194. #
  195. #
  196. # def _load_entry_from_hubconf(m, model):
  197. # if not isinstance(model, str):
  198. # raise ValueError('Invalid input: model should be a string of function name')
  199. #
  200. # # Note that if a missing dependency is imported at top level of hubconf, it will
  201. # # throw before this function. It's a chicken and egg situation where we have to
  202. # # load hubconf to know what're the dependencies, but to import hubconf it requires
  203. # # a missing package. This is fine, Python will throw proper error message for users.
  204. # _check_dependencies(m)
  205. #
  206. # func = _load_attr_from_module(m, model)
  207. #
  208. # if func is None or not callable(func):
  209. # raise RuntimeError('Cannot find callable {} in hubconf'.format(model))
  210. #
  211. # return func
  212. #
  213. #
  214. # def get_dir():
  215. # r"""
  216. # Get the Torch Hub cache directory used for storing downloaded models & weights.
  217. #
  218. # If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where
  219. # environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
  220. # ``$XDG_CACHE_HOME`` follows the X Desktop Group specification of the Linux
  221. # filesystem layout, with a default value ``~/.cache`` if the environment
  222. # variable is not set.
  223. # """
  224. # # Issue warning to move data if old env is set
  225. # if os.getenv('TORCH_HUB'):
  226. # warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead')
  227. #
  228. # if _hub_dir is not None:
  229. # return _hub_dir
  230. # return os.path.join(_get_torch_home(), 'hub')
  231. #
  232. #
  233. # def set_dir(d):
  234. # r"""
  235. # Optionally set the Torch Hub directory used to save downloaded models & weights.
  236. #
  237. # Args:
  238. # d (string): path to a local folder to save downloaded models & weights.
  239. # """
  240. # global _hub_dir
  241. # _hub_dir = d
  242. #
  243. #
  244. # def list(github, force_reload=False):
  245. # r"""
  246. # List all entrypoints available in `github` hubconf.
  247. #
  248. # Args:
  249. # github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional
  250. # tag/branch. The default branch is `master` if not specified.
  251. # Example: 'pytorch/vision[:hub]'
  252. # force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
  253. # Default is `False`.
  254. # Returns:
  255. # entrypoints: a list of available entrypoint names
  256. #
  257. # Example:
  258. # >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
  259. # """
  260. # repo_dir = _get_cache_or_reload(github, force_reload, True)
  261. #
  262. # sys.path.insert(0, repo_dir)
  263. #
  264. # hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)
  265. #
  266. # sys.path.remove(repo_dir)
  267. #
  268. # # We treat functions starting with '_' as internal helper functions
  269. # entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]
  270. #
  271. # return entrypoints
  272. #
  273. #
  274. # def help(github, model, force_reload=False):
  275. # r"""
  276. # Show the docstring of entrypoint `model`.
  277. #
  278. # Args:
  279. # github (string): a string with format <repo_owner/repo_name[:tag_name]> with an optional
  280. # tag/branch. The default branch is `master` if not specified.
  281. # Example: 'pytorch/vision[:hub]'
  282. # model (string): a string of entrypoint name defined in repo's hubconf.py
  283. # force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
  284. # Default is `False`.
  285. # Example:
  286. # >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
  287. # """
  288. # repo_dir = _get_cache_or_reload(github, force_reload, True)
  289. #
  290. # sys.path.insert(0, repo_dir)
  291. #
  292. # hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)
  293. #
  294. # sys.path.remove(repo_dir)
  295. #
  296. # entry = _load_entry_from_hubconf(hub_module, model)
  297. #
  298. # return entry.__doc__
  299. #
  300. #
  301. # # Ideally this should be `def load(github, model, *args, force_reload=False, **kwargs):`,
  302. # # but Python2 complains syntax error for it. We have to skip force_reload in function
  303. # # signature here but detect it in kwargs instead.
  304. # # TODO: fix it after Python2 EOL
  305. # def load(repo_or_dir, model, *args, **kwargs):
  306. # r"""
  307. # Load a model from a github repo or a local directory.
  308. #
  309. # Note: Loading a model is the typical use case, but this can also be used
  310. # for loading other objects such as tokenizers, loss functions, etc.
  311. #
  312. # If :attr:`source` is ``'github'``, :attr:`repo_or_dir` is expected to be
  313. # of the form ``repo_owner/repo_name[:tag_name]`` with an optional
  314. # tag/branch.
  315. #
  316. # If :attr:`source` is ``'local'``, :attr:`repo_or_dir` is expected to be a
  317. # path to a local directory.
  318. #
  319. # Args:
  320. # repo_or_dir (string): repo name (``repo_owner/repo_name[:tag_name]``),
  321. # if ``source = 'github'``; or a path to a local directory, if
  322. # ``source = 'local'``.
  323. # model (string): the name of a callable (entrypoint) defined in the
  324. # repo/dir's ``hubconf.py``.
  325. # *args (optional): the corresponding args for callable :attr:`model`.
  326. # source (string, optional): ``'github'`` | ``'local'``. Specifies how
  327. # ``repo_or_dir`` is to be interpreted. Default is ``'github'``.
  328. # force_reload (bool, optional): whether to force a fresh download of
  329. # the github repo unconditionally. Does not have any effect if
  330. # ``source = 'local'``. Default is ``False``.
  331. # verbose (bool, optional): If ``False``, mute messages about hitting
  332. # local caches. Note that the message about first download cannot be
  333. # muted. Does not have any effect if ``source = 'local'``.
  334. # Default is ``True``.
  335. # **kwargs (optional): the corresponding kwargs for callable
  336. # :attr:`model`.
  337. #
  338. # Returns:
  339. # The output of the :attr:`model` callable when called with the given
  340. # ``*args`` and ``**kwargs``.
  341. #
  342. # Example:
  343. # >>> # from a github repo
  344. # >>> repo = 'pytorch/vision'
  345. # >>> model = torch.hub.load(repo, 'resnet50', pretrained=True)
  346. # >>> # from a local directory
  347. # >>> path = '/some/local/path/pytorch/vision'
  348. # >>> model = torch.hub.load(path, 'resnet50', pretrained=True)
  349. # """
  350. # source = kwargs.pop('source', 'github').lower()
  351. # force_reload = kwargs.pop('force_reload', False)
  352. # verbose = kwargs.pop('verbose', True)
  353. #
  354. # if source not in ('github', 'local'):
  355. # raise ValueError(
  356. # f'Unknown source: "{source}". Allowed values: "github" | "local".')
  357. #
  358. # if source == 'github':
  359. # repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, verbose)
  360. #
  361. # model = _load_local(repo_or_dir, model, *args, **kwargs)
  362. # return model
  363. #
  364. #
  365. # def _load_local(hubconf_dir, model, *args, **kwargs):
  366. # r"""
  367. # Load a model from a local directory with a ``hubconf.py``.
  368. #
  369. # Args:
  370. # hubconf_dir (string): path to a local directory that contains a
  371. # ``hubconf.py``.
  372. # model (string): name of an entrypoint defined in the directory's
  373. # `hubconf.py`.
  374. # *args (optional): the corresponding args for callable ``model``.
  375. # **kwargs (optional): the corresponding kwargs for callable ``model``.
  376. #
  377. # Returns:
  378. # a single model with corresponding pretrained weights.
  379. #
  380. # Example:
  381. # >>> path = '/some/local/path/pytorch/vision'
  382. # >>> model = _load_local(path, 'resnet50', pretrained=True)
  383. # """
  384. # sys.path.insert(0, hubconf_dir)
  385. #
  386. # hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF)
  387. # hub_module = import_module(MODULE_HUBCONF, hubconf_path)
  388. #
  389. # entry = _load_entry_from_hubconf(hub_module, model)
  390. # model = entry(*args, **kwargs)
  391. #
  392. # sys.path.remove(hubconf_dir)
  393. #
  394. # return model
  395. #
  396. #
  397. # def download_url_to_file(url, dst, hash_prefix=None, progress=True):
  398. # r"""Download object at the given URL to a local path.
  399. #
  400. # Args:
  401. # url (string): URL of the object to download
  402. # dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file`
  403. # hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with `hash_prefix`.
  404. # Default: None
  405. # progress (bool, optional): whether or not to display a progress bar to stderr
  406. # Default: True
  407. #
  408. # Example:
  409. # >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
  410. #
  411. # """
  412. # file_size = None
  413. # # We use a different API for python2 since urllib(2) doesn't recognize the CA
  414. # # certificates in older Python
  415. # req = Request(url, headers={"User-Agent": "torch.hub"})
  416. # u = urlopen(req)
  417. # meta = u.info()
  418. # if hasattr(meta, 'getheaders'):
  419. # content_length = meta.getheaders("Content-Length")
  420. # else:
  421. # content_length = meta.get_all("Content-Length")
  422. # if content_length is not None and len(content_length) > 0:
  423. # file_size = int(content_length[0])
  424. #
  425. # # We deliberately save it in a temp file and move it after
  426. # # download is complete. This prevents a local working checkpoint
  427. # # being overridden by a broken download.
  428. # dst = os.path.expanduser(dst)
  429. # dst_dir = os.path.dirname(dst)
  430. # f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
  431. #
  432. # try:
  433. # if hash_prefix is not None:
  434. # sha256 = hashlib.sha256()
  435. # with tqdm(total=file_size, disable=not progress,
  436. # unit='B', unit_scale=True, unit_divisor=1024) as pbar:
  437. # while True:
  438. # buffer = u.read(8192)
  439. # if len(buffer) == 0:
  440. # break
  441. # f.write(buffer)
  442. # if hash_prefix is not None:
  443. # sha256.update(buffer)
  444. # pbar.update(len(buffer))
  445. #
  446. # f.close()
  447. # if hash_prefix is not None:
  448. # digest = sha256.hexdigest()
  449. # if digest[:len(hash_prefix)] != hash_prefix:
  450. # raise RuntimeError('invalid hash value (expected "{}", got "{}")'
  451. # .format(hash_prefix, digest))
  452. # shutil.move(f.name, dst)
  453. # finally:
  454. # f.close()
  455. # if os.path.exists(f.name):
  456. # os.remove(f.name)
  457. #
  458. # def _download_url_to_file(url, dst, hash_prefix=None, progress=True):
  459. # warnings.warn('torch.hub._download_url_to_file has been renamed to\
  460. # torch.hub.download_url_to_file to be a public API,\
  461. # _download_url_to_file will be removed in after 1.3 release')
  462. # download_url_to_file(url, dst, hash_prefix, progress)
  463. #
  464. # # Hub used to support automatically extracting files from zipfiles manually compressed by users.
  465. # # The legacy zip format expects only one file from torch.save() < 1.6 in the zip.
  466. # # We should remove this support since zipfile is now default zipfile format for torch.save().
  467. # def _is_legacy_zip_format(filename):
  468. # if zipfile.is_zipfile(filename):
  469. # infolist = zipfile.ZipFile(filename).infolist()
  470. # return len(infolist) == 1 and not infolist[0].is_dir()
  471. # return False
  472. #
  473. # def _legacy_zip_load(filename, model_dir, map_location):
  474. # warnings.warn('Falling back to the old format < 1.6. This support will be '
  475. # 'deprecated in favor of default zipfile format introduced in 1.6. '
  476. # 'Please redo torch.save() to save it in the new zipfile format.')
  477. # # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
  478. # # We deliberately don't handle tarfile here since our legacy serialization format was in tar.
  479. # # E.g. resnet18-5c106cde.pth which is widely used.
  480. # with zipfile.ZipFile(filename) as f:
  481. # members = f.infolist()
  482. # if len(members) != 1:
  483. # raise RuntimeError('Only one file(not dir) is allowed in the zipfile')
  484. # f.extractall(model_dir)
  485. # extraced_name = members[0].filename
  486. # extracted_file = os.path.join(model_dir, extraced_name)
  487. # return torch.load(extracted_file, map_location=map_location)
  488. #
  489. # def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None):
  490. # r"""Loads the Torch serialized object at the given URL.
  491. #
  492. # If downloaded file is a zip file, it will be automatically
  493. # decompressed.
  494. #
  495. # If the object is already present in `model_dir`, it's deserialized and
  496. # returned.
  497. # The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
  498. # `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.
  499. #
  500. # Args:
  501. # url (string): URL of the object to download
  502. # model_dir (string, optional): directory in which to save the object
  503. # map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
  504. # progress (bool, optional): whether or not to display a progress bar to stderr.
  505. # Default: True
  506. # check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
  507. # ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
  508. # digits of the SHA256 hash of the contents of the file. The hash is used to
  509. # ensure unique names and to verify the contents of the file.
  510. # Default: False
  511. # file_name (string, optional): name for the downloaded file. Filename from `url` will be used if not set.
  512. #
  513. # Example:
  514. # >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
  515. #
  516. # """
  517. # # Issue warning to move data if old env is set
  518. # if os.getenv('TORCH_MODEL_ZOO'):
  519. # warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
  520. #
  521. # if model_dir is None:
  522. # hub_dir = get_dir()
  523. # model_dir = os.path.join(hub_dir, 'checkpoints')
  524. #
  525. # try:
  526. # os.makedirs(model_dir)
  527. # except OSError as e:
  528. # if e.errno == errno.EEXIST:
  529. # # Directory already exists, ignore.
  530. # pass
  531. # else:
  532. # # Unexpected OSError, re-raise.
  533. # raise
  534. #
  535. # parts = urlparse(url)
  536. # filename = os.path.basename(parts.path)
  537. # if file_name is not None:
  538. # filename = file_name
  539. # cached_file = os.path.join(model_dir, filename)
  540. # if not os.path.exists(cached_file):
  541. # sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
  542. # hash_prefix = None
  543. # if check_hash:
  544. # r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
  545. # hash_prefix = r.group(1) if r else None
  546. # download_url_to_file(url, cached_file, hash_prefix, progress=progress)
  547. #
  548. # if _is_legacy_zip_format(cached_file):
  549. # return _legacy_zip_load(cached_file, model_dir, map_location)
  550. # return torch.load(cached_file, map_location=map_location)

TODS 是一个全栈的自动化机器学习系统，主要针对多变量时间序列数据的异常检测。TODS 提供了详尽的用于构建基于机器学习的异常检测系统的模块，它们包括：数据处理（data processing）、时间序列处理（time series processing）、特征分析（feature analysis）、检测算法（detection algorithms）和强化模块（reinforcement module）。这些模块所提供的功能包括常见的数据预处理、时间序列数据的平滑或变换，从时域或频域中抽取特征、多种多样的检测算