You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cityscapes.py 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # ---------------------------------------------------------------------
  10. # Part of the following code in this file refs to torchvision
  11. # BSD 3-Clause License
  12. #
  13. # Copyright (c) Soumith Chintala 2016,
  14. # All rights reserved.
  15. # ---------------------------------------------------------------------
  16. import json
  17. import os
  18. import cv2
  19. import numpy as np
  20. from .meta_vision import VisionDataset
  21. class Cityscapes(VisionDataset):
  22. r"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset."""
  23. supported_order = (
  24. "image",
  25. "mask",
  26. "info",
  27. )
  28. def __init__(self, root, image_set, mode, *, order=None):
  29. super().__init__(root, order=order, supported_order=self.supported_order)
  30. city_root = self.root
  31. if not os.path.isdir(city_root):
  32. raise RuntimeError("Dataset not found or corrupted.")
  33. self.mode = mode
  34. self.images_dir = os.path.join(city_root, "leftImg8bit", image_set)
  35. self.masks_dir = os.path.join(city_root, self.mode, image_set)
  36. self.images, self.masks = [], []
  37. # self.target_type = ["instance", "semantic", "polygon", "color"]
  38. # for semantic segmentation
  39. if mode == "gtFine":
  40. valid_modes = ("train", "test", "val")
  41. else:
  42. valid_modes = ("train", "train_extra", "val")
  43. for city in os.listdir(self.images_dir):
  44. img_dir = os.path.join(self.images_dir, city)
  45. mask_dir = os.path.join(self.masks_dir, city)
  46. for file_name in os.listdir(img_dir):
  47. mask_name = "{}_{}".format(
  48. file_name.split("_leftImg8bit")[0],
  49. self._get_target_suffix(self.mode, "semantic"),
  50. )
  51. self.images.append(os.path.join(img_dir, file_name))
  52. self.masks.append(os.path.join(mask_dir, mask_name))
  53. def __getitem__(self, index):
  54. target = []
  55. for k in self.order:
  56. if k == "image":
  57. image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
  58. target.append(image)
  59. elif k == "mask":
  60. mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
  61. mask = self._trans_mask(mask)
  62. mask = mask[:, :, np.newaxis]
  63. target.append(mask)
  64. elif k == "info":
  65. if image is None:
  66. image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
  67. info = [image.shape[0], image.shape[1], self.images[index]]
  68. target.append(info)
  69. else:
  70. raise NotImplementedError
  71. return tuple(target)
  72. def __len__(self):
  73. return len(self.images)
  74. def _trans_mask(self, mask):
  75. trans_labels = [
  76. 7,
  77. 8,
  78. 11,
  79. 12,
  80. 13,
  81. 17,
  82. 19,
  83. 20,
  84. 21,
  85. 22,
  86. 23,
  87. 24,
  88. 25,
  89. 26,
  90. 27,
  91. 28,
  92. 31,
  93. 32,
  94. 33,
  95. ]
  96. label = np.ones(mask.shape) * 255
  97. for i, tl in enumerate(trans_labels):
  98. label[mask == tl] = i
  99. return label.astype(np.uint8)
  100. def _get_target_suffix(self, mode, target_type):
  101. if target_type == "instance":
  102. return "{}_instanceIds.png".format(mode)
  103. elif target_type == "semantic":
  104. return "{}_labelIds.png".format(mode)
  105. elif target_type == "color":
  106. return "{}_color.png".format(mode)
  107. else:
  108. return "{}_polygons.json".format(mode)
  109. def _load_json(self, path):
  110. with open(path, "r") as file:
  111. data = json.load(file)
  112. return data
  113. class_names = (
  114. "road",
  115. "sidewalk",
  116. "building",
  117. "wall",
  118. "fence",
  119. "pole",
  120. "traffic light",
  121. "traffic sign",
  122. "vegetation",
  123. "terrain",
  124. "sky",
  125. "person",
  126. "rider",
  127. "car",
  128. "truck",
  129. "bus",
  130. "train",
  131. "motorcycle",
  132. "bicycle",
  133. )

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。