# Modified from https://github.com/davidtvs/PyTorch-ENet/blob/master/data/camvid.py
import os
import torch.utils.data as data
from glob import glob
from PIL import Image
import numpy as np
from torchvision.datasets import VisionDataset


class CamVid(VisionDataset):
    """CamVid dataset loader where the dataset is arranged as in
    https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid.

    Expects the directory layout ``<root>/<split>/*.png`` for images and
    ``<root>/<split>annot/*.png`` for the per-pixel annotation masks.

    Args:
        root (string): Root directory of the dataset.
        split (string): The type of dataset: 'train', 'val', 'trainval', or 'test'.
        transform (callable, optional): A function/transform that takes in
            a PIL image and returns a transformed version. Default: None.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it. Default: None.
        transforms (callable, optional): A function/transform that takes in
            both the image and target and transforms them. Default: None.
    """

    # RGB color for each of the 12 class indices; index 11 is the void class.
    cmap = np.array([
        (128, 128, 128),
        (128, 0, 0),
        (192, 192, 128),
        (128, 64, 128),
        (60, 40, 222),
        (128, 128, 0),
        (192, 128, 128),
        (64, 64, 128),
        (64, 0, 128),
        (64, 64, 0),
        (0, 128, 192),
        (0, 0, 0),
    ])

    def __init__(self,
                 root,
                 split='train',
                 transform=None,
                 target_transform=None,
                 transforms=None):
        assert split in ('train', 'val', 'test', 'trainval')
        super().__init__(root=root,
                         transforms=transforms,
                         transform=transform,
                         target_transform=target_transform)
        self.root = os.path.expanduser(root)
        self.split = split

        if split == 'trainval':
            # 'trainval' is the union of the 'train' and 'val' splits.
            self.images = glob(os.path.join(self.root, 'train', '*.png')) + \
                glob(os.path.join(self.root, 'val', '*.png'))
            self.labels = glob(os.path.join(self.root, 'trainannot', '*.png')) + \
                glob(os.path.join(self.root, 'valannot', '*.png'))
        else:
            self.images = glob(os.path.join(self.root, self.split, '*.png'))
            self.labels = glob(os.path.join(self.root, self.split + 'annot', '*.png'))

        # Sorting both lists pairs images with their annotation masks, which
        # works because CamVid uses identical file names in both directories.
        self.images.sort()
        self.labels.sort()

        # Fail loudly if the directories are misaligned instead of silently
        # pairing the wrong image/label files.
        if len(self.images) != len(self.labels):
            raise RuntimeError(
                "Number of images (%d) does not match number of labels (%d) "
                "for split '%s' under %s" %
                (len(self.images), len(self.labels), split, self.root))

    def __getitem__(self, idx):
        """Return the (image, label) pair at ``idx``.

        Args:
            idx (int): index of the item in the dataset.

        Returns:
            A tuple ``(image, label)``. The pair is loaded as PIL images;
            ``self.transforms`` is expected to convert both to tensors, since
            the void-class remapping and ``squeeze`` below require tensor
            semantics.
        """
        img = Image.open(self.images[idx])
        label = Image.open(self.labels[idx])
        if self.transforms is not None:
            img, label = self.transforms(img, label)
        label[label == 11] = 255  # remap void class to the ignore index
        return img, label.squeeze(0)

    def __len__(self):
        return len(self.images)

    @classmethod
    def decode_fn(cls, mask):
        """Decode a semantic mask of class indices to an RGB image.

        Args:
            mask: integer array of class indices, where 255 marks ignored
                (void) pixels.

        Returns:
            An array of RGB triples with the same spatial shape as ``mask``.
        """
        # Map the ignore index back to the void class without mutating the
        # caller's array (the previous in-place assignment clobbered it).
        mask = np.where(mask == 255, 11, mask)
        return cls.cmap[mask]