import os
from shutil import move

import numpy as np
from PIL import Image
from scipy.io import loadmat
from torch.utils import data

from .utils import download_url


class StanfordDogs(data.Dataset):
    """Dataset for Stanford Dogs.

    Reads ``<root>/<split>_list.mat`` for the image file list and labels, and
    loads images from ``<root>/Images``.

    Args:
        root: Dataset directory; ``~`` is expanded and the path is made
            absolute.
        split: Prefix of the list file to load (``<split>_list.mat``).
            Defaults to ``'train'``.
        download: If true, call :meth:`download` before reading the lists.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        target_transform: Optional callable applied to the label.

    Attributes:
        files: Image paths relative to ``<root>/Images``, as stored in the
            ``file_list`` entry of the list file.
        labels: ``numpy`` array of the list file's ``labels`` shifted down by
            one (i.e. zero-based if the file stores one-based labels).
        object_categories: Sorted sub-directory names of ``<root>/Images``
            with their first 10 characters removed (presumably the WordNet-id
            prefix such as ``n02085620-`` -- confirm against the archive).
    """

    urls = {
        "images.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar",
        "annotation.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/annotation.tar",
        "lists.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/lists.tar",
    }

    def __init__(self, root, split='train', download=False, transform=None,
                 target_transform=None):
        self.root = os.path.abspath(os.path.expanduser(root))
        self.split = split
        self.transform = transform
        self.target_transform = target_transform
        if download:
            self.download()

        list_file = os.path.join(self.root, self.split + '_list.mat')
        mat_file = loadmat(list_file)
        file_list = mat_file['file_list']
        labels = mat_file['labels']
        size = len(file_list)

        # Each entry is a nested MATLAB cell; the path string sits at [0][0].
        self.files = [str(entry[0][0]) for entry in file_list]
        # Shift the stored labels down by one.
        self.labels = np.array([labels[i][0] - 1 for i in range(size)])

        images_dir = os.path.join(self.root, 'Images')
        # Only sub-directories are categories; stray files (e.g. OS metadata
        # such as ".DS_Store") must not show up as classes.
        categories = sorted(
            name for name in os.listdir(images_dir)
            if os.path.isdir(os.path.join(images_dir, name))
        )
        # Drop the fixed-width 10-character prefix of each directory name.
        self.object_categories = [c[10:] for c in categories]

        print('Stanford Dogs, Split: %s, Size: %d' % (self.split, len(self)))

    def __len__(self):
        """Return the number of images listed for this split."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The image is opened from ``<root>/Images/<files[idx]>`` and converted
        to RGB; ``transform`` and ``target_transform`` are applied when set.
        """
        path = os.path.join(self.root, 'Images', self.files[idx])
        # The context manager closes the underlying file deterministically;
        # ``convert`` returns an independent, fully loaded image.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        lbl = self.labels[idx]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            lbl = self.target_transform(lbl)
        return img, lbl

    def download(self):
        """Fetch any missing archives from ``urls`` and extract them into root.

        An archive is downloaded only when ``<root>/<name>`` does not already
        exist; every archive is (re-)extracted on each call.
        """
        import tarfile

        os.makedirs(self.root, exist_ok=True)
        for fname, url in self.urls.items():
            archive = os.path.join(self.root, fname)
            if not os.path.isfile(archive):
                # Expected to store the file at ``<root>/<fname>``, which is
                # the path opened below.
                download_url(url, self.root, fname)

            # extract file
            print("Extracting %s..." % fname)
            with tarfile.open(archive, "r") as tar:
                # The archives come over plain HTTP, so treat them as
                # untrusted: where the interpreter supports extraction
                # filters, refuse members that escape ``root`` or create
                # special files.
                if hasattr(tarfile, "data_filter"):
                    tar.extractall(path=self.root, filter="data")
                else:
                    tar.extractall(path=self.root)