# Modified from https://github.com/pytorch/vision
import os
import sys
import tarfile
import collections
import torch.utils.data as data
import shutil
import numpy as np

from .utils import colormap
from torchvision.datasets import VisionDataset
import torch
from PIL import Image
from torchvision.datasets.utils import download_url, check_integrity

DATASET_YEAR_DICT = {
    '2012aug': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
        'filename': 'VOCtrainval_11-May-2012.tar',
        'md5': '6cd6e144f989b92b3379bac3b3de84fd',
        'base_dir': 'VOCdevkit/VOC2012'
    },
    '2012': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
        'filename': 'VOCtrainval_11-May-2012.tar',
        'md5': '6cd6e144f989b92b3379bac3b3de84fd',
        'base_dir': 'VOCdevkit/VOC2012'
    },
    '2011': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
        'filename': 'VOCtrainval_25-May-2011.tar',
        'md5': '6c3384ef61512963050cb5d687e5bf1e',
        'base_dir': 'TrainVal/VOCdevkit/VOC2011'
    },
    '2010': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
        'filename': 'VOCtrainval_03-May-2010.tar',
        'md5': 'da459979d0c395079b5c75ee67908abb',
        'base_dir': 'VOCdevkit/VOC2010'
    },
    '2009': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
        'filename': 'VOCtrainval_11-May-2009.tar',
        'md5': '59065e4b188729180974ef6572f6a212',
        'base_dir': 'VOCdevkit/VOC2009'
    },
    '2008': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
        # filename corrected to match the 2008 archive name in the URL above
        'filename': 'VOCtrainval_14-Jul-2008.tar',
        'md5': '2629fa636546599198acfcfbfcf1904a',
        'base_dir': 'VOCdevkit/VOC2008'
    },
    '2007': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
        'filename': 'VOCtrainval_06-Nov-2007.tar',
        'md5': 'c52e279531787c972589f7e41ab4ae64',
        'base_dir': 'VOCdevkit/VOC2007'
    }
}


class VOCSegmentation(VisionDataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year; supports years 2007 to 2012,
            plus ``2012aug`` for the 2012 set with the augmented training masks
            (``SegmentationClassAug``).
        image_set (string, optional): Select the image_set to use, ``train``,
            ``trainval`` or ``val``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is
            not downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``.
        transforms (callable, optional): A function/transform that takes in an image
            and its segmentation mask and returns transformed versions of both.
    """
    cmap = colormap()

    def __init__(self,
                 root,
                 year='2012',
                 image_set='train',
                 download=False,
                 transform=None,
                 target_transform=None,
                 transforms=None):
        super(VOCSegmentation, self).__init__(
            root, transform=transform, target_transform=target_transform, transforms=transforms)

        is_aug = False
        if year == '2012aug':
            is_aug = True
            year = '2012'

        self.root = os.path.expanduser(root)
        self.year = year
        self.url = DATASET_YEAR_DICT[year]['url']
        self.filename = DATASET_YEAR_DICT[year]['filename']
        self.md5 = DATASET_YEAR_DICT[year]['md5']
        self.image_set = image_set

        base_dir = DATASET_YEAR_DICT[year]['base_dir']
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, 'JPEGImages')

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if is_aug and image_set == 'train':
            # The augmented masks are not part of the official VOC archive and
            # must be prepared separately (see README.md).
            mask_dir = os.path.join(voc_root, 'SegmentationClassAug')
            assert os.path.exists(mask_dir), \
                "SegmentationClassAug not found, please refer to README.md and prepare it manually"
            split_f = os.path.join(self.root, 'train_aug.txt')
        else:
            mask_dir = os.path.join(voc_root, 'SegmentationClass')
            splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
            split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')

        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"')

        with open(split_f, "r") as f:
            file_names = [x.strip() for x in f.readlines()]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        assert len(self.images) == len(self.masks)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the segmentation mask.
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target.squeeze(0)

    def __len__(self):
        return len(self.images)

    @classmethod
    def decode_fn(cls, mask):
        """Decode a semantic mask of class indices to an RGB image."""
        return cls.cmap[mask]


def download_extract(url, root, filename, md5):
    download_url(url, root, filename, md5)
    with tarfile.open(os.path.join(root, filename), "r") as tar:
        tar.extractall(path=root)


CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
           'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
           'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']


class VOCClassification(data.Dataset):
    def __init__(self, root, year='2010', split='train', download=False,
                 transforms=None, target_transforms=None):
        voc_root = os.path.join(root, 'VOC{}'.format(year))
        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')
        self.transforms = transforms
        self.target_transforms = target_transforms

        image_dir = os.path.join(voc_root, 'JPEGImages')
        label_dir = os.path.join(voc_root, 'ImageSets/Main')

        self.labels_list = []
        fname = os.path.join(label_dir, '{}.txt'.format(split))
        with open(fname) as f:
            self.images = [os.path.join(image_dir, line.split()[0] + '.jpg') for line in f]

        # One label file per class; each line is "<image_id> <label>" with
        # label in {-1, 0, 1} (absent / difficult / present).
        for clas in CLASSES:
            with open(os.path.join(label_dir, '{}_{}.txt'.format(clas, split))) as f:
                labels = [int(line.split()[1]) for line in f]
            self.labels_list.append(labels)
        assert len(self.images) == len(self.labels_list[0])

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, labels) where labels is the list of per-class
                presence labels for the image.
        """
        img = Image.open(self.images[index]).convert('RGB')
        labels = [labels[index] for labels in self.labels_list]
        if self.transforms is not None:
            img = self.transforms(img)
        if self.target_transforms is not None:
            labels = self.target_transforms(labels)
        return img, labels

    def __len__(self):
        return len(self.images)
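

if __name__ == "__main__":
    # Minimal usage sketch for VOCSegmentation. Assumptions: the VOC2012 archive
    # has already been extracted under ./datasets/data (the path below is only
    # illustrative), and a joint transform is supplied because the dataset calls
    # `self.transforms(img, target)` on (image, mask) pairs and then squeezes the
    # mask, so the transform is expected to return tensors. `ToTensorPair` is a
    # hypothetical helper defined here, not part of this package.
    import torchvision.transforms.functional as TF

    class ToTensorPair:
        """Convert a (PIL image, PIL mask) pair to (float tensor, long tensor)."""
        def __call__(self, img, target):
            img = TF.to_tensor(img)                                         # (3, H, W), floats in [0, 1]
            target = torch.as_tensor(np.array(target), dtype=torch.int64)   # (H, W), class indices
            return img, target

    dataset = VOCSegmentation(root='./datasets/data', year='2012',
                              image_set='val', transforms=ToTensorPair())
    img, target = dataset[0]
    print(len(dataset), img.shape, target.shape)

    # decode_fn maps class indices back to the VOC palette for visualization,
    # assuming colormap() returns an indexable (256, 3) uint8 array.
    color_mask = VOCSegmentation.decode_fn(target.numpy().astype('uint8'))
    print(color_mask.shape)  # expected (H, W, 3)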