You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

cityscapes.py 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # ---------------------------------------------------------------------
  10. # Part of the following code in this file refers to torchvision
  11. # BSD 3-Clause License
  12. #
  13. # Copyright (c) Soumith Chintala 2016,
  14. # All rights reserved.
  15. # ---------------------------------------------------------------------
  16. import json
  17. import os
  18. import cv2
  19. import numpy as np
  20. from .meta_vision import VisionDataset
  21. class Cityscapes(VisionDataset):
  22. r"""
  23. `Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
  24. """
  25. supported_order = (
  26. "image",
  27. "mask",
  28. "info",
  29. )
  30. def __init__(self, root, image_set, mode, *, order=None):
  31. super().__init__(root, order=order, supported_order=self.supported_order)
  32. city_root = self.root
  33. if not os.path.isdir(city_root):
  34. raise RuntimeError("Dataset not found or corrupted.")
  35. self.mode = mode
  36. self.images_dir = os.path.join(city_root, "leftImg8bit", image_set)
  37. self.masks_dir = os.path.join(city_root, self.mode, image_set)
  38. self.images, self.masks = [], []
  39. # self.target_type = ["instance", "semantic", "polygon", "color"]
  40. # for semantic segmentation
  41. if mode == "gtFine":
  42. valid_modes = ("train", "test", "val")
  43. else:
  44. valid_modes = ("train", "train_extra", "val")
  45. for city in os.listdir(self.images_dir):
  46. img_dir = os.path.join(self.images_dir, city)
  47. mask_dir = os.path.join(self.masks_dir, city)
  48. for file_name in os.listdir(img_dir):
  49. mask_name = "{}_{}".format(
  50. file_name.split("_leftImg8bit")[0],
  51. self._get_target_suffix(self.mode, "semantic"),
  52. )
  53. self.images.append(os.path.join(img_dir, file_name))
  54. self.masks.append(os.path.join(mask_dir, mask_name))
  55. def __getitem__(self, index):
  56. target = []
  57. for k in self.order:
  58. if k == "image":
  59. image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
  60. target.append(image)
  61. elif k == "mask":
  62. mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
  63. mask = self._trans_mask(mask)
  64. mask = mask[:, :, np.newaxis]
  65. target.append(mask)
  66. elif k == "info":
  67. if image is None:
  68. image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
  69. info = [image.shape[0], image.shape[1], self.images[index]]
  70. target.append(info)
  71. else:
  72. raise NotImplementedError
  73. return tuple(target)
  74. def __len__(self):
  75. return len(self.images)
  76. def _trans_mask(self, mask):
  77. trans_labels = [
  78. 7,
  79. 8,
  80. 11,
  81. 12,
  82. 13,
  83. 17,
  84. 19,
  85. 20,
  86. 21,
  87. 22,
  88. 23,
  89. 24,
  90. 25,
  91. 26,
  92. 27,
  93. 28,
  94. 31,
  95. 32,
  96. 33,
  97. ]
  98. label = np.ones(mask.shape) * 255
  99. for i, tl in enumerate(trans_labels):
  100. label[mask == tl] = i
  101. return label.astype(np.uint8)
  102. def _get_target_suffix(self, mode, target_type):
  103. if target_type == "instance":
  104. return "{}_instanceIds.png".format(mode)
  105. elif target_type == "semantic":
  106. return "{}_labelIds.png".format(mode)
  107. elif target_type == "color":
  108. return "{}_color.png".format(mode)
  109. else:
  110. return "{}_polygons.json".format(mode)
  111. def _load_json(self, path):
  112. with open(path, "r") as file:
  113. data = json.load(file)
  114. return data
  115. class_names = (
  116. "road",
  117. "sidewalk",
  118. "building",
  119. "wall",
  120. "fence",
  121. "pole",
  122. "traffic light",
  123. "traffic sign",
  124. "vegetation",
  125. "terrain",
  126. "sky",
  127. "person",
  128. "rider",
  129. "car",
  130. "truck",
  131. "bus",
  132. "train",
  133. "motorcycle",
  134. "bicycle",
  135. )

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台