#Modified from https://github.com/pytorch/vision/pull/467/files from __future__ import print_function import torch.utils.data as data from torchvision.datasets.folder import pil_loader, accimage_loader, default_loader from PIL import Image import os import numpy as np from .utils import download_url, mkdir def make_dataset(dir, image_ids, targets): assert(len(image_ids) == len(targets)) images = [] dir = os.path.expanduser(dir) for i in range(len(image_ids)): item = (os.path.join(dir, 'fgvc-aircraft-2013b', 'data', 'images', '%s.jpg' % image_ids[i]), targets[i]) images.append(item) return images def find_classes(classes_file): # read classes file, separating out image IDs and class names image_ids = [] targets = [] f = open(classes_file, 'r') for line in f: split_line = line.split(' ') image_ids.append(split_line[0]) targets.append(' '.join(split_line[1:])) f.close() # index class names classes = np.unique(targets) class_to_idx = {classes[i]: i for i in range(len(classes))} targets = [class_to_idx[c] for c in targets] return (image_ids, targets, classes, class_to_idx) class FGVCAircraft(data.Dataset): """`FGVC-Aircraft `_ Dataset. Args: root (string): Root directory path to dataset. class_type (string, optional): The level of FGVC-Aircraft fine-grain classification to label data with (i.e., ``variant``, ``family``, or ``manufacturer``). transforms (callable, optional): A function/transforms that takes in a PIL image and returns a transformed version. E.g. ``transforms.RandomCrop`` target_transform (callable, optional): A function/transforms that takes in the target and transforms it. loader (callable, optional): A function to load an image given its path. download (bool, optional): If true, downloads the dataset from the internet and puts it in the root directory. If dataset is already downloaded, it is not downloaded again. """ url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz' class_types = ('variant', 'family', 'manufacturer') splits = ('train', 'val', 'trainval', 'test') def __init__(self, root, class_type='variant', split='train', transform=None, target_transform=None, loader=default_loader, download=False): if split not in self.splits: raise ValueError('Split "{}" not found. Valid splits are: {}'.format( split, ', '.join(self.splits), )) if class_type not in self.class_types: raise ValueError('Class type "{}" not found. Valid class types are: {}'.format( class_type, ', '.join(self.class_types), )) self.root = root self.class_type = class_type self.split = split self.classes_file = os.path.join(self.root, 'fgvc-aircraft-2013b', 'data', 'images_%s_%s.txt' % (self.class_type, self.split)) if download: self.download() (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file) samples = make_dataset(self.root, image_ids, targets) self.transform = transform self.target_transform = target_transform self.loader = loader self.samples = samples self.classes = classes self.class_to_idx = class_to_idx with open(os.path.join(self.root, 'fgvc-aircraft-2013b/data', 'variants.txt')) as f: self.object_categories = [ line.strip('\n') for line in f.readlines()] print('FGVC-Aircraft, Split: %s, Size: %d' % (self.split, self.__len__())) def __getitem__(self, index): """ Args: index (int): Index Returns: tuple: (sample, target) where target is class_index of the target class. """ path, target = self.samples[index] sample = self.loader(path) if self.transform is not None: sample = self.transform(sample) if self.target_transform is not None: target = self.target_transform(target) return sample, target def __len__(self): return len(self.samples) def __repr__(self): fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) fmt_str += ' Root Location: {}\n'.format(self.root) tmp = ' Transforms (if any): ' fmt_str += '{0}{1}\n'.format( tmp, self.transforms.__repr__().replace('\n', '\n' + ' ' * len(tmp))) tmp = ' Target Transforms (if any): ' fmt_str += '{0}{1}'.format( tmp, self.target_transforms.__repr__().replace('\n', '\n' + ' ' * len(tmp))) return fmt_str def _check_exists(self): return os.path.exists(os.path.join(self.root, 'data', 'images')) and \ os.path.exists(self.classes_file) def download(self): """Download the FGVC-Aircraft data if it doesn't exist already.""" from six.moves import urllib import tarfile mkdir(self.root) fpath = os.path.join(self.root, 'fgvc-aircraft-2013b.tar.gz') if not os.path.isfile(fpath): download_url(self.url, self.root, 'fgvc-aircraft-2013b.tar.gz') print("Extracting fgvc-aircraft-2013b.tar.gz...") with tarfile.open(fpath, "r:gz") as tar: tar.extractall(path=self.root)