You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cityscapes.py 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # -*- coding: utf-8 -*-
  2. # ---------------------------------------------------------------------
  3. # Part of the code in this file is adapted from torchvision
  4. # BSD 3-Clause License
  5. #
  6. # Copyright (c) Soumith Chintala 2016,
  7. # All rights reserved.
  8. # ---------------------------------------------------------------------
  9. import json
  10. import os
  11. import cv2
  12. import numpy as np
  13. from .meta_vision import VisionDataset
  14. class Cityscapes(VisionDataset):
  15. r"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset."""
  16. supported_order = (
  17. "image",
  18. "mask",
  19. "info",
  20. )
  21. def __init__(self, root, image_set, mode, *, order=None):
  22. super().__init__(root, order=order, supported_order=self.supported_order)
  23. city_root = self.root
  24. if not os.path.isdir(city_root):
  25. raise RuntimeError("Dataset not found or corrupted.")
  26. self.mode = mode
  27. self.images_dir = os.path.join(city_root, "leftImg8bit", image_set)
  28. self.masks_dir = os.path.join(city_root, self.mode, image_set)
  29. self.images, self.masks = [], []
  30. # self.target_type = ["instance", "semantic", "polygon", "color"]
  31. # for semantic segmentation
  32. if mode == "gtFine":
  33. valid_modes = ("train", "test", "val")
  34. else:
  35. valid_modes = ("train", "train_extra", "val")
  36. for city in os.listdir(self.images_dir):
  37. img_dir = os.path.join(self.images_dir, city)
  38. mask_dir = os.path.join(self.masks_dir, city)
  39. for file_name in os.listdir(img_dir):
  40. mask_name = "{}_{}".format(
  41. file_name.split("_leftImg8bit")[0],
  42. self._get_target_suffix(self.mode, "semantic"),
  43. )
  44. self.images.append(os.path.join(img_dir, file_name))
  45. self.masks.append(os.path.join(mask_dir, mask_name))
  46. def __getitem__(self, index):
  47. target = []
  48. for k in self.order:
  49. if k == "image":
  50. image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
  51. target.append(image)
  52. elif k == "mask":
  53. mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
  54. mask = self._trans_mask(mask)
  55. mask = mask[:, :, np.newaxis]
  56. target.append(mask)
  57. elif k == "info":
  58. if image is None:
  59. image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
  60. info = [image.shape[0], image.shape[1], self.images[index]]
  61. target.append(info)
  62. else:
  63. raise NotImplementedError
  64. return tuple(target)
  65. def __len__(self):
  66. return len(self.images)
  67. def _trans_mask(self, mask):
  68. trans_labels = [
  69. 7,
  70. 8,
  71. 11,
  72. 12,
  73. 13,
  74. 17,
  75. 19,
  76. 20,
  77. 21,
  78. 22,
  79. 23,
  80. 24,
  81. 25,
  82. 26,
  83. 27,
  84. 28,
  85. 31,
  86. 32,
  87. 33,
  88. ]
  89. label = np.ones(mask.shape) * 255
  90. for i, tl in enumerate(trans_labels):
  91. label[mask == tl] = i
  92. return label.astype(np.uint8)
  93. def _get_target_suffix(self, mode, target_type):
  94. if target_type == "instance":
  95. return "{}_instanceIds.png".format(mode)
  96. elif target_type == "semantic":
  97. return "{}_labelIds.png".format(mode)
  98. elif target_type == "color":
  99. return "{}_color.png".format(mode)
  100. else:
  101. return "{}_polygons.json".format(mode)
  102. def _load_json(self, path):
  103. with open(path, "r") as file:
  104. data = json.load(file)
  105. return data
  106. class_names = (
  107. "road",
  108. "sidewalk",
  109. "building",
  110. "wall",
  111. "fence",
  112. "pole",
  113. "traffic light",
  114. "traffic sign",
  115. "vegetation",
  116. "terrain",
  117. "sky",
  118. "person",
  119. "rider",
  120. "car",
  121. "truck",
  122. "bus",
  123. "train",
  124. "motorcycle",
  125. "bicycle",
  126. )