# Modified from https://github.com/pytorch/vision/blob/master/torchvision/datasets/caltech.py from __future__ import print_function from PIL import Image import os import os.path from torchvision.datasets.vision import VisionDataset from torchvision.datasets.utils import download_url class Caltech101(VisionDataset): """`Caltech 101 `_ Dataset. Args: root (string): Root directory of dataset where directory ``caltech101`` exists or will be saved to if download is set to True. target_type (string or list, optional): Type of target to use, ``category`` or ``annotation``. Can also be a list to output a tuple with all specified target types. ``category`` represents the target class, and ``annotation`` is a list of points from a hand-generated outline. Defaults to ``category``. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. """ def __init__(self, root, target_type="category", train=True, transform=None, target_transform=None, download=False): super(Caltech101, self).__init__(os.path.join(root, 'caltech101')) self.train = train self.dir_name = '101_ObjectCategories_split/train' if self.train else '101_ObjectCategories_split/test' os.makdirs(self.root, exist_ok=True) if isinstance(target_type, list): self.target_type = target_type else: self.target_type = [target_type] self.transform = transform self.target_transform = target_transform if download: self.download() if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.' + ' You can use download=True to download it') self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories"))) self.categories.remove("BACKGROUND_Google") # this is not a real class # For some reason, the category names in "101_ObjectCategories" and # "Annotations" do not always match. This is a manual map between the # two. Defaults to using same name, since most names are fine. name_map = {"Faces": "Faces_2", "Faces_easy": "Faces_3", "Motorbikes": "Motorbikes_16", "airplanes": "Airplanes_Side_2"} self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories)) self.index = [] self.y = [] for (i, c) in enumerate(self.categories): file_names = os.listdir(os.path.join(self.root, self.dir_name, c)) n = len(file_names) self.index.extend( file_names ) self.y.extend(n * [i]) print(self.train, len(self.index)) def __getitem__(self, index): """ Args: index (int): Index Returns: tuple: (image, target) where the type of target specified by target_type. """ import scipy.io img = Image.open(os.path.join(self.root, self.dir_name, self.categories[self.y[index]], self.index[index])).convert("RGB") target = [] for t in self.target_type: if t == "category": target.append(self.y[index]) elif t == "annotation": data = scipy.io.loadmat(os.path.join(self.root, "Annotations", self.annotation_categories[self.y[index]], "annotation_{:04d}.mat".format(self.index[index]))) target.append(data["obj_contour"]) else: raise ValueError("Target type \"{}\" is not recognized.".format(t)) target = tuple(target) if len(target) > 1 else target[0] if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target = self.target_transform(target) return img, target def _check_integrity(self): # can be more robust and check hash of files return os.path.exists(os.path.join(self.root, "101_ObjectCategories")) def __len__(self): return len(self.index) def download(self): import tarfile if self._check_integrity(): print('Files already downloaded and verified') return download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz", self.root, "101_ObjectCategories.tar.gz", "b224c7392d521a49829488ab0f1120d9") download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar", self.root, "101_Annotations.tar", "6f83eeb1f24d99cab4eb377263132c91") # extract file with tarfile.open(os.path.join(self.root, "101_ObjectCategories.tar.gz"), "r:gz") as tar: tar.extractall(path=self.root) with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar: tar.extractall(path=self.root) def extra_repr(self): return "Target type: {target_type}".format(**self.__dict__) class Caltech256(VisionDataset): """`Caltech 256 `_ Dataset. Args: root (string): Root directory of dataset where directory ``caltech256`` exists or will be saved to if download is set to True. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. """ def __init__(self, root, transform=None, target_transform=None, download=False): super(Caltech256, self).__init__(os.path.join(root, 'caltech256')) os.makedirs(self.root, exist_ok=True) self.transform = transform self.target_transform = target_transform if download: self.download() if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.' + ' You can use download=True to download it') self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories"))) self.index = [] self.y = [] for (i, c) in enumerate(self.categories): n = len(os.listdir(os.path.join(self.root, "256_ObjectCategories", c))) self.index.extend(range(1, n + 1)) self.y.extend(n * [i]) def __getitem__(self, index): """ Args: index (int): Index Returns: tuple: (image, target) where target is index of the target class. """ img = Image.open(os.path.join(self.root, "256_ObjectCategories", self.categories[self.y[index]], "{:03d}_{:04d}.jpg".format(self.y[index] + 1, self.index[index]))) target = self.y[index] if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target = self.target_transform(target) return img, target def _check_integrity(self): # can be more robust and check hash of files return os.path.exists(os.path.join(self.root, "256_ObjectCategories")) def __len__(self): return len(self.index) def download(self): import tarfile if self._check_integrity(): print('Files already downloaded and verified') return download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar", self.root, "256_ObjectCategories.tar", "67b4f42ca05d46448c6bb8ecd2220f6d") # extract file with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar: tar.extractall(path=self.root)