You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. # Modified from https://github.com/pytorch/vision
  2. import os
  3. import os.path
  4. import hashlib
  5. import errno
  6. from tqdm import tqdm
  7. import numpy as np
  8. import torch
  9. import random
  10. def mkdir(dir):
  11. if not os.path.isdir(dir):
  12. os.mkdir(dir)
  13. def colormap(N=256, normalized=False):
  14. def bitget(byteval, idx):
  15. return ((byteval & (1 << idx)) != 0)
  16. dtype = 'float32' if normalized else 'uint8'
  17. cmap = np.zeros((N, 3), dtype=dtype)
  18. for i in range(N):
  19. r = g = b = 0
  20. c = i
  21. for j in range(8):
  22. r = r | (bitget(c, 0) << 7-j)
  23. g = g | (bitget(c, 1) << 7-j)
  24. b = b | (bitget(c, 2) << 7-j)
  25. c = c >> 3
  26. cmap[i] = np.array([r, g, b])
  27. cmap = cmap/255 if normalized else cmap
  28. return cmap
  29. DEFAULT_COLORMAP = colormap()
  30. def gen_bar_updater(pbar):
  31. def bar_update(count, block_size, total_size):
  32. if pbar.total is None and total_size:
  33. pbar.total = total_size
  34. progress_bytes = count * block_size
  35. pbar.update(progress_bytes - pbar.n)
  36. return bar_update
  37. def check_integrity(fpath, md5=None):
  38. if md5 is None:
  39. return True
  40. if not os.path.isfile(fpath):
  41. return False
  42. md5o = hashlib.md5()
  43. with open(fpath, 'rb') as f:
  44. # read in 1MB chunks
  45. for chunk in iter(lambda: f.read(1024 * 1024), b''):
  46. md5o.update(chunk)
  47. md5c = md5o.hexdigest()
  48. if md5c != md5:
  49. return False
  50. return True
  51. def makedir_exist_ok(dirpath):
  52. """
  53. Python2 support for os.makedirs(.., exist_ok=True)
  54. """
  55. try:
  56. os.makedirs(dirpath)
  57. except OSError as e:
  58. if e.errno == errno.EEXIST:
  59. pass
  60. else:
  61. raise
  62. def download_url(url, root, filename=None, md5=None):
  63. """Download a file from a url and place it in root.
  64. Args:
  65. url (str): URL to download file from
  66. root (str): Directory to place downloaded file in
  67. filename (str): Name to save the file under. If None, use the basename of the URL
  68. md5 (str): MD5 checksum of the download. If None, do not check
  69. """
  70. from six.moves import urllib
  71. root = os.path.expanduser(root)
  72. if not filename:
  73. filename = os.path.basename(url)
  74. fpath = os.path.join(root, filename)
  75. makedir_exist_ok(root)
  76. # downloads file
  77. if os.path.isfile(fpath) and check_integrity(fpath, md5):
  78. print('Using downloaded and verified file: ' + fpath)
  79. else:
  80. try:
  81. print('Downloading ' + url + ' to ' + fpath)
  82. urllib.request.urlretrieve(
  83. url, fpath,
  84. reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True))
  85. )
  86. except OSError:
  87. if url[:5] == 'https':
  88. url = url.replace('https:', 'http:')
  89. print('Failed download. Trying https -> http instead.'
  90. ' Downloading ' + url + ' to ' + fpath)
  91. urllib.request.urlretrieve(
  92. url, fpath,
  93. reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True))
  94. )
  95. def list_dir(root, prefix=False):
  96. """List all directories at a given root
  97. Args:
  98. root (str): Path to directory whose folders need to be listed
  99. prefix (bool, optional): If true, prepends the path to each result, otherwise
  100. only returns the name of the directories found
  101. """
  102. root = os.path.expanduser(root)
  103. directories = list(
  104. filter(
  105. lambda p: os.path.isdir(os.path.join(root, p)),
  106. os.listdir(root)
  107. )
  108. )
  109. if prefix is True:
  110. directories = [os.path.join(root, d) for d in directories]
  111. return directories
  112. def list_files(root, suffix, prefix=False):
  113. """List all files ending with a suffix at a given root
  114. Args:
  115. root (str): Path to directory whose folders need to be listed
  116. suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
  117. It uses the Python "str.endswith" method and is passed directly
  118. prefix (bool, optional): If true, prepends the path to each result, otherwise
  119. only returns the name of the files found
  120. """
  121. root = os.path.expanduser(root)
  122. files = list(
  123. filter(
  124. lambda p: os.path.isfile(os.path.join(
  125. root, p)) and p.endswith(suffix),
  126. os.listdir(root)
  127. )
  128. )
  129. if prefix is True:
  130. files = [os.path.join(root, d) for d in files]
  131. return files
  132. def set_seed(random_seed):
  133. torch.manual_seed(random_seed)
  134. torch.cuda.manual_seed(random_seed)
  135. np.random.seed(random_seed)
  136. random.seed(random_seed)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具，在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势，目前已为产学研等各领域近千家单位及个人提供AI应用赋能。

Contributors (1)