You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cityscapes.py 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # ---------------------------------------------------------------------
  10. # Part of the following code in this file refs to torchvision
  11. # BSD 3-Clause License
  12. #
  13. # Copyright (c) Soumith Chintala 2016,
  14. # All rights reserved.
  15. # ---------------------------------------------------------------------
  16. import json
  17. import os
  18. import cv2
  19. import numpy as np
  20. from .meta_vision import VisionDataset
  21. class Cityscapes(VisionDataset):
  22. r"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
  23. """
  24. supported_order = (
  25. "image",
  26. "mask",
  27. "info",
  28. )
  29. def __init__(self, root, image_set, mode, *, order=None):
  30. super().__init__(root, order=order, supported_order=self.supported_order)
  31. city_root = self.root
  32. if not os.path.isdir(city_root):
  33. raise RuntimeError("Dataset not found or corrupted.")
  34. self.mode = mode
  35. self.images_dir = os.path.join(city_root, "leftImg8bit", image_set)
  36. self.masks_dir = os.path.join(city_root, self.mode, image_set)
  37. self.images, self.masks = [], []
  38. # self.target_type = ["instance", "semantic", "polygon", "color"]
  39. # for semantic segmentation
  40. if mode == "gtFine":
  41. valid_modes = ("train", "test", "val")
  42. else:
  43. valid_modes = ("train", "train_extra", "val")
  44. for city in os.listdir(self.images_dir):
  45. img_dir = os.path.join(self.images_dir, city)
  46. mask_dir = os.path.join(self.masks_dir, city)
  47. for file_name in os.listdir(img_dir):
  48. mask_name = "{}_{}".format(
  49. file_name.split("_leftImg8bit")[0],
  50. self._get_target_suffix(self.mode, "semantic"),
  51. )
  52. self.images.append(os.path.join(img_dir, file_name))
  53. self.masks.append(os.path.join(mask_dir, mask_name))
  54. def __getitem__(self, index):
  55. target = []
  56. for k in self.order:
  57. if k == "image":
  58. image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
  59. target.append(image)
  60. elif k == "mask":
  61. mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
  62. mask = self._trans_mask(mask)
  63. mask = mask[:, :, np.newaxis]
  64. target.append(mask)
  65. elif k == "info":
  66. if image is None:
  67. image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
  68. info = [image.shape[0], image.shape[1], self.images[index]]
  69. target.append(info)
  70. else:
  71. raise NotImplementedError
  72. return tuple(target)
  73. def __len__(self):
  74. return len(self.images)
  75. def _trans_mask(self, mask):
  76. trans_labels = [
  77. 7,
  78. 8,
  79. 11,
  80. 12,
  81. 13,
  82. 17,
  83. 19,
  84. 20,
  85. 21,
  86. 22,
  87. 23,
  88. 24,
  89. 25,
  90. 26,
  91. 27,
  92. 28,
  93. 31,
  94. 32,
  95. 33,
  96. ]
  97. label = np.ones(mask.shape) * 255
  98. for i, tl in enumerate(trans_labels):
  99. label[mask == tl] = i
  100. return label.astype(np.uint8)
  101. def _get_target_suffix(self, mode, target_type):
  102. if target_type == "instance":
  103. return "{}_instanceIds.png".format(mode)
  104. elif target_type == "semantic":
  105. return "{}_labelIds.png".format(mode)
  106. elif target_type == "color":
  107. return "{}_color.png".format(mode)
  108. else:
  109. return "{}_polygons.json".format(mode)
  110. def _load_json(self, path):
  111. with open(path, "r") as file:
  112. data = json.load(file)
  113. return data
  114. class_names = (
  115. "road",
  116. "sidewalk",
  117. "building",
  118. "wall",
  119. "fence",
  120. "pole",
  121. "traffic light",
  122. "traffic sign",
  123. "vegetation",
  124. "terrain",
  125. "sky",
  126. "person",
  127. "rider",
  128. "car",
  129. "truck",
  130. "bus",
  131. "train",
  132. "motorcycle",
  133. "bicycle",
  134. )

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。