You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

prepare_stl10.py 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  from __future__ import print_function

  import argparse
  import errno
  import os
  import sys
  import tarfile

  import numpy as np
  import matplotlib.pyplot as plt

  if sys.version_info >= (3, 0, 0):
      import urllib.request as urllib  # ugly but works: exposes urlretrieve under the Python 2 name
  else:
      import urllib

  # PNG writer used by save_image. Only ImportError is caught so that genuine errors raised
  # while importing a backend are not masked. scipy.misc.imsave no longer exists in
  # SciPy >= 1.2, so fall back to matplotlib's writer (already a dependency of this script).
  try:
      from imageio import imsave
  except ImportError:
      try:
          from scipy.misc import imsave
      except ImportError:
          from matplotlib.pyplot import imsave

  print(sys.version_info)

  # image shape
  HEIGHT = 96
  WIDTH = 96
  DEPTH = 3
  # size of a single image in bytes
  SIZE = HEIGHT * WIDTH * DEPTH
  # url of the binary data
  DATA_URL = 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz'
  # Input and output paths are built from the --DATA_DIR command-line argument in __main__.
  def read_labels(path_to_labels):
      """
      Load the label file of one STL-10 split.

      :param path_to_labels: path to the binary file containing labels from the STL-10 dataset
      :return: a 1-D numpy array of dtype uint8, one entry per byte of the file
      """
      with open(path_to_labels, mode='rb') as label_stream:
          return np.fromfile(label_stream, dtype=np.uint8)
  def read_all_images(path_to_data):
      """
      Load every image stored in an STL-10 image file.

      Each image occupies 3*96*96 bytes: the red plane, then the green plane, then the
      blue plane. Each plane is stored in column-major order.

      :param path_to_data: the file containing the binary images from the STL-10 dataset
      :return: a uint8 array of shape (num_images, 96, 96, 3) laid out as
               (image, row, column, channel), which matplotlib.imshow can display directly
      """
      with open(path_to_data, 'rb') as data_stream:
          raw_bytes = np.fromfile(data_stream, dtype=np.uint8)
      # -> (image, channel, column, row); -1 lets numpy infer the image count from the
      # file size. The last two axes are (column, row) because each plane is column-major.
      planar = raw_bytes.reshape((-1, 3, 96, 96))
      # Channels last plus a column/row swap gives (image, row, column, channel).
      # Skip this (or use another permutation) if a CNN wants channels-first input.
      return planar.transpose((0, 3, 2, 1))
  def read_single_image(image_file):
      """
      Read the next image from an already-open STL-10 image file.

      CAREFUL! - this method takes an open file instead of a path. Each call reads up to
      SIZE bytes and advances the file position, and that position persists outside
      this method. If fewer than SIZE bytes remain, the reshape below raises ValueError.

      :param image_file: the open file containing the images (should be opened in binary mode)
      :return: a single uint8 image of shape (96, 96, 3), channels last
      """
      pixels = np.fromfile(image_file, dtype=np.uint8, count=SIZE)
      # (channel, column, row) -> reversing all three axes yields (row, column, channel).
      # Skip the transpose if a CNN wants channels-first input.
      return pixels.reshape((3, 96, 96)).transpose((2, 1, 0))
  def plot_image(image):
      """
      Show one image on screen with matplotlib.

      :param image: the image to be plotted, a 3-D (row, column, channel) array such as
                    the one returned by read_single_image
      :return: None
      """
      plt.imshow(image)
      plt.show()
  def save_image(image, name):
      """
      Write *image* to ``<name>.png`` in PNG format.

      :param image: image array to write, e.g. one entry of read_all_images' result
      :param name: output path without the ``.png`` extension
      :return: None
      """
      target_path = "%s.png" % name
      imsave(target_path, image, format="png")
  def download_and_extract(DATA_DIR,
                           data_url='http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz'):
      """
      Download and extract the STL-10 dataset.

      The archive is downloaded only if ``<DATA_DIR>/<archive name>`` does not already
      exist. It is extracted into DATA_DIR on every call.

      :param DATA_DIR: directory that receives the archive and its extracted contents;
                       created if missing
      :param data_url: URL of the tar.gz archive; its last path component is used as
                       the local file name
      :return: None
      """
      dest_directory = DATA_DIR
      if not os.path.exists(dest_directory):
          os.makedirs(dest_directory)
      filename = data_url.split('/')[-1]
      filepath = os.path.join(dest_directory, filename)
      if not os.path.exists(filepath):
          def _progress(count, block_size, total_size):
              done = count * block_size
              if total_size > 0:
                  # The final block can overshoot the total, so clamp at 100%.
                  percent = min(100.0, float(done) / float(total_size) * 100.0)
                  sys.stdout.write('\rDownloading %s %.2f%%' % (filename, percent))
              else:
                  # urlretrieve reports -1 when the size is unknown; show bytes instead.
                  sys.stdout.write('\rDownloading %s %d bytes' % (filename, done))
              sys.stdout.flush()
          filepath, _ = urllib.urlretrieve(data_url, filepath, reporthook=_progress)
          # Terminate the '\r' progress line so the summary is not printed over it.
          sys.stdout.write('\n')
          print('Downloaded', filename)
      with tarfile.open(filepath, 'r:gz') as archive:
          if hasattr(tarfile, 'data_filter'):
              # The archive arrives over plain HTTP. The 'data' filter refuses members that
              # would land outside dest_directory (absolute paths, '..', unsafe links).
              archive.extractall(dest_directory, filter='data')
          else:
              archive.extractall(dest_directory)
  def _ensure_dir(path):
      """Create *path* and any missing parents unless it is already a directory."""
      if os.path.isdir(path):
          return
      try:
          os.makedirs(path)
      except OSError as err:
          # Another process may have created it meanwhile; Python 2 has no exist_ok=True.
          if err.errno != errno.EEXIST or not os.path.isdir(path):
              raise


  def save_images(images, labels, save_dir):
      """
      Write every image to disk as a PNG file named after its index.

      With labels, image ``i`` is written to ``<save_dir>/<labels[i]>/<i>.png``. With
      ``labels=None`` it is written to ``<save_dir>/<i>.png``. Missing directories
      (including parents of *save_dir*) are created.

      :param images: iterable of images accepted by save_image (e.g. read_all_images' result)
      :param labels: sequence indexed in step with *images*, or None for unlabeled data
      :param save_dir: root output directory
      :return: None
      """
      print("Saving images to disk")
      _ensure_dir(save_dir)
      for i, image in enumerate(images):
          if labels is not None:
              directory = os.path.join(save_dir, str(labels[i]))
              _ensure_dir(directory)
          else:
              directory = save_dir
          filename = os.path.join(directory, str(i))
          print(filename)
          save_image(image, filename)
  def _prepare_split(data_dir, base_path, split, has_labels):
      """
      Convert one STL-10 split into PNG files under ``<base_path>/<split>``.

      Reads ``<data_dir>/<split>_X.bin`` and, when *has_labels* is true,
      ``<data_dir>/<split>_y.bin``, then passes both to save_images.

      :param data_dir: directory containing the STL-10 ``*.bin`` files
      :param base_path: root output directory
      :param split: split name used in the input file names and the output folder
      :param has_labels: whether the split has a matching label file
      :return: None
      """
      print('Preparing %s' % split.capitalize())
      images = read_all_images(os.path.join(data_dir, '%s_X.bin' % split))
      print(images.shape)
      labels = None
      if has_labels:
          labels = read_labels(os.path.join(data_dir, '%s_y.bin' % split))
          print(labels.shape)
      save_images(images, labels, os.path.join(base_path, split))


  if __name__ == "__main__":
      # download data if needed: download_and_extract(<directory>) fetches and unpacks the archive
      parser = argparse.ArgumentParser(
          description='Convert the STL-10 binary files into PNG images.')
      parser.add_argument('--DATA_DIR', type=str, default='../stl10_binary',
                          help='directory containing the train/test/unlabeled *.bin files')
      args = parser.parse_args()
      BASE_PATH = os.path.join(args.DATA_DIR, 'stl10_data')
      if not os.path.exists(BASE_PATH):
          os.makedirs(BASE_PATH)
      # One call per split: each split's arrays are local to _prepare_split, so they can be
      # released before the next file is read instead of coexisting with it.
      for split, has_labels in (('train', True), ('test', True), ('unlabeled', False)):
          _prepare_split(args.DATA_DIR, BASE_PATH, split, has_labels)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具，在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势，目前已为产学研等各领域近千家单位及个人提供AI应用赋能。

Contributors (1)