from __future__ import print_function import sys import os, sys, tarfile, errno import numpy as np import matplotlib.pyplot as plt import argparse if sys.version_info >= (3, 0, 0): import urllib.request as urllib # ugly but works else: import urllib try: from imageio import imsave except: from scipy.misc import imsave print(sys.version_info) # image shape HEIGHT = 96 WIDTH = 96 DEPTH = 3 # size of a single image in bytes SIZE = HEIGHT * WIDTH * DEPTH # path to the directory with the data #DATA_DIR = './data' # url of the binary data #DATA_URL = 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz' # path to the binary train file with image data #DATA_PATH = './data/stl10_binary/train_X.bin' # path to the binary train file with labels #LABEL_PATH = './data/stl10_binary/train_y.bin' def read_labels(path_to_labels): """ :param path_to_labels: path to the binary file containing labels from the STL-10 dataset :return: an array containing the labels """ with open(path_to_labels, 'rb') as f: labels = np.fromfile(f, dtype=np.uint8) return labels def read_all_images(path_to_data): """ :param path_to_data: the file containing the binary images from the STL-10 dataset :return: an array containing all the images """ with open(path_to_data, 'rb') as f: # read whole file in uint8 chunks everything = np.fromfile(f, dtype=np.uint8) # We force the data into 3x96x96 chunks, since the # images are stored in "column-major order", meaning # that "the first 96*96 values are the red channel, # the next 96*96 are green, and the last are blue." # The -1 is since the size of the pictures depends # on the input file, and this way numpy determines # the size on its own. images = np.reshape(everything, (-1, 3, 96, 96)) # Now transpose the images into a standard image format # readable by, for example, matplotlib.imshow # You might want to comment this line or reverse the shuffle # if you will use a learning algorithm like CNN, since they like # their channels separated. images = np.transpose(images, (0, 3, 2, 1)) return images def read_single_image(image_file): """ CAREFUL! - this method uses a file as input instead of the path - so the position of the reader will be remembered outside of context of this method. :param image_file: the open file containing the images :return: a single image """ # read a single image, count determines the number of uint8's to read image = np.fromfile(image_file, dtype=np.uint8, count=SIZE) # force into image matrix image = np.reshape(image, (3, 96, 96)) # transpose to standard format # You might want to comment this line or reverse the shuffle # if you will use a learning algorithm like CNN, since they like # their channels separated. image = np.transpose(image, (2, 1, 0)) return image def plot_image(image): """ :param image: the image to be plotted in a 3-D matrix format :return: None """ plt.imshow(image) plt.show() def save_image(image, name): imsave("%s.png" % name, image, format="png") def download_and_extract(DATA_DIR): """ Download and extract the STL-10 dataset :return: None """ dest_directory = DATA_DIR if not os.path.exists(dest_directory): os.makedirs(dest_directory) filename = DATA_URL.split('/')[-1] filepath = os.path.join(dest_directory, filename) if not os.path.exists(filepath): def _progress(count, block_size, total_size): sys.stdout.write('\rDownloading %s %.2f%%' % (filename, float(count * block_size) / float(total_size) * 100.0)) sys.stdout.flush() filepath, _ = urllib.urlretrieve(DATA_URL, filepath, reporthook=_progress) print('Downloaded', filename) tarfile.open(filepath, 'r:gz').extractall(dest_directory) def save_images(images, labels, save_dir): print("Saving images to disk") i = 0 if not os.path.exists(save_dir): os.mkdir(save_dir) for image in images: if labels is not None: label = labels[i] directory = os.path.join( save_dir, str(label) ) if not os.path.exists(directory): os.makedirs(directory) else: directory = save_dir filename = os.path.join( directory, str(i) ) print(filename) save_image(image, filename) i = i+1 if __name__ == "__main__": # download data if needed # download_and_extract() parser = argparse.ArgumentParser() parser.add_argument('--DATA_DIR', type=str, default='../stl10_binary') args = parser.parse_args() BASE_PATH = os.path.join( args.DATA_DIR, 'stl10_data' ) if not os.path.exists(BASE_PATH): os.mkdir(BASE_PATH) # Train DATA_PATH = os.path.join( args.DATA_DIR, 'train_X.bin') LABEL_PATH = os.path.join( args.DATA_DIR, 'train_y.bin') print('Preparing Train') images = read_all_images(DATA_PATH) print(images.shape) labels = read_labels(LABEL_PATH) print(labels.shape) save_images(images, labels, os.path.join(BASE_PATH, 'train')) # Test DATA_PATH = os.path.join( args.DATA_DIR, 'test_X.bin') LABEL_PATH = os.path.join( args.DATA_DIR, 'test_y.bin') #with open(DATA_PATH) as f: # image = read_single_image(f) # plot_image(image) print('Preparing Test') images = read_all_images(DATA_PATH) print(images.shape) labels = read_labels(LABEL_PATH) print(labels.shape) save_images(images, labels, os.path.join(BASE_PATH, 'test')) # Unlabeled print('Preparing Unlabeled') DATA_PATH = os.path.join( args.DATA_DIR, 'unlabeled_X.bin') images = read_all_images(DATA_PATH) save_images(images, None, os.path.join(BASE_PATH, 'unlabeled'))