You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

cityscapes.py 8.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # Modified from https://github.com/pytorch/vision/blob/master/torchvision/datasets/cityscapes.py
  2. import json
  3. import os
  4. from collections import namedtuple
  5. import torch
  6. import torch.utils.data as data
  7. from PIL import Image
  8. import numpy as np
  9. from torchvision.datasets import VisionDataset
  10. class Cityscapes(VisionDataset):
  11. """Cityscapes <http://www.cityscapes-dataset.com/> Dataset.
  12. Args:
  13. root (string): Root directory of dataset where directory 'leftImg8bit' and 'gtFine' or 'gtCoarse' are located.
  14. split (string, optional): The image split to use, 'train', 'test' or 'val' if mode="gtFine" otherwise 'train', 'train_extra' or 'val'
  15. mode (string, optional): The quality mode to use, 'gtFine' or 'gtCoarse' or 'color'. Can also be a list to output a tuple with all specified target types.
  16. transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``
  17. target_transform (callable, optional): A function/transform that takes in the target and transforms it.
  18. """
  19. # Based on https://github.com/mcordts/cityscapesScripts
  20. CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
  21. 'has_instances', 'ignore_in_eval', 'color'])
  22. classes = [
  23. CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
  24. CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
  25. CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
  26. CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
  27. CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
  28. CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
  29. CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
  30. CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
  31. CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
  32. CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
  33. CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
  34. CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
  35. CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
  36. CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
  37. CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
  38. CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
  39. CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
  40. CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
  41. CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
  42. CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
  43. CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
  44. CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
  45. CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
  46. CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
  47. CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
  48. CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
  49. CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
  50. CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
  51. CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
  52. CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
  53. CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
  54. CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
  55. CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
  56. CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
  57. CityscapesClass('license plate', -1, 255, 'vehicle', 7, False, True, (0, 0, 142)),
  58. ]
  59. _TRAIN_ID_TO_COLOR = [c.color for c in classes if (c.train_id != -1 and c.train_id != 255)]
  60. _TRAIN_ID_TO_COLOR.append([0, 0, 0])
  61. _TRAIN_ID_TO_COLOR = np.array(_TRAIN_ID_TO_COLOR)
  62. _ID_TO_TRAIN_ID = np.array([c.train_id for c in classes])
  63. def __init__(self, root, split='train', mode='gtFine', target_type='semantic', transform=None, target_transform=None, transforms=None):
  64. super(Cityscapes, self).__init__( root, transform=transform, target_transform=target_transform, transforms=transforms )
  65. self.root = os.path.expanduser(root)
  66. self.mode = mode
  67. self.target_type = target_type
  68. self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
  69. self.targets_dir = os.path.join(self.root, self.mode, split)
  70. self.split = split
  71. self.images = []
  72. self.targets = []
  73. if split not in ['train', 'test', 'val']:
  74. raise ValueError('Invalid split for mode! Please use split="train", split="test"'
  75. ' or split="val"')
  76. if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
  77. raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
  78. ' specified "split" and "mode" are inside the "root" directory')
  79. for city in os.listdir(self.images_dir):
  80. img_dir = os.path.join(self.images_dir, city)
  81. target_dir = os.path.join(self.targets_dir, city)
  82. for file_name in os.listdir(img_dir):
  83. self.images.append(os.path.join(img_dir, file_name))
  84. target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
  85. self._get_target_suffix(self.mode, self.target_type))
  86. self.targets.append(os.path.join(target_dir, target_name))
  87. @classmethod
  88. def encode_target(cls, target):
  89. if isinstance( target, torch.Tensor ):
  90. return torch.from_numpy( cls._ID_TO_TRAIN_ID[np.array(target)] )
  91. else:
  92. return cls._ID_TO_TRAIN_ID[target]
  93. @classmethod
  94. def decode_fn(cls, target):
  95. target[target == 255] = 19
  96. #target = target.astype('uint8') + 1
  97. return cls._TRAIN_ID_TO_COLOR[target]
  98. def __getitem__(self, index):
  99. """
  100. Args:
  101. index (int): Index
  102. Returns:
  103. tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
  104. than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
  105. """
  106. image = Image.open(self.images[index]).convert('RGB')
  107. target = Image.open(self.targets[index])
  108. if self.transforms:
  109. image, target = self.transforms(image, target)
  110. target = self.encode_target(target)
  111. return image, target
  112. def __len__(self):
  113. return len(self.images)
  114. def _load_json(self, path):
  115. with open(path, 'r') as file:
  116. data = json.load(file)
  117. return data
  118. def _get_target_suffix(self, mode, target_type):
  119. if target_type == 'instance':
  120. return '{}_instanceIds.png'.format(mode)
  121. elif target_type == 'semantic':
  122. return '{}_labelIds.png'.format(mode)
  123. elif target_type == 'color':
  124. return '{}_color.png'.format(mode)
  125. elif target_type == 'polygon':
  126. return '{}_polygons.json'.format(mode)
  127. elif target_type == 'depth':
  128. return '{}_disparity.png'.format(mode)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具，在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势，目前已为产学研等各领域近千家单位及个人提供AI应用赋能。

Contributors (1)