You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

voc.py 6.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # ---------------------------------------------------------------------
  10. # Part of the following code in this file refers to torchvision
  11. # BSD 3-Clause License
  12. #
  13. # Copyright (c) Soumith Chintala 2016,
  14. # All rights reserved.
  15. # ---------------------------------------------------------------------
  16. import collections.abc
  17. import os
  18. import xml.etree.ElementTree as ET
  19. import cv2
  20. import numpy as np
  21. from .meta_vision import VisionDataset
  22. class PascalVOC(VisionDataset):
  23. r"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset.
  24. """
  25. supported_order = (
  26. "image",
  27. "boxes",
  28. "boxes_category",
  29. "mask",
  30. "info",
  31. )
  32. def __init__(self, root, image_set, *, order=None):
  33. if ("boxes" in order or "boxes_category" in order) and "mask" in order:
  34. raise ValueError(
  35. "PascalVOC only supports boxes & boxes_category or mask, not both."
  36. )
  37. super().__init__(root, order=order, supported_order=self.supported_order)
  38. if not os.path.isdir(self.root):
  39. raise RuntimeError("Dataset not found or corrupted.")
  40. self.image_set = image_set
  41. image_dir = os.path.join(self.root, "JPEGImages")
  42. if "boxes" in order or "boxes_category" in order:
  43. annotation_dir = os.path.join(self.root, "Annotations")
  44. splitdet_dir = os.path.join(self.root, "ImageSets/Main")
  45. split_f = os.path.join(splitdet_dir, image_set.rstrip("\n") + ".txt")
  46. with open(os.path.join(split_f), "r") as f:
  47. self.file_names = [x.strip() for x in f.readlines()]
  48. self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names]
  49. self.annotations = [
  50. os.path.join(annotation_dir, x + ".xml") for x in self.file_names
  51. ]
  52. assert len(self.images) == len(self.annotations)
  53. elif "mask" in order:
  54. if "aug" in image_set:
  55. mask_dir = os.path.join(self.root, "SegmentationClass_aug")
  56. else:
  57. mask_dir = os.path.join(self.root, "SegmentationClass")
  58. splitmask_dir = os.path.join(self.root, "ImageSets/Segmentation")
  59. split_f = os.path.join(splitmask_dir, image_set.rstrip("\n") + ".txt")
  60. with open(os.path.join(split_f), "r") as f:
  61. self.file_names = [x.strip() for x in f.readlines()]
  62. self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names]
  63. self.masks = [os.path.join(mask_dir, x + ".png") for x in self.file_names]
  64. assert len(self.images) == len(self.masks)
  65. else:
  66. raise NotImplementedError
  67. def __getitem__(self, index):
  68. target = []
  69. for k in self.order:
  70. if k == "image":
  71. image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
  72. target.append(image)
  73. elif k == "boxes":
  74. anno = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
  75. boxes = [obj["bndbox"] for obj in anno["annotation"]["object"]]
  76. # boxes type xyxy
  77. boxes = [
  78. (bb["xmin"], bb["ymin"], bb["xmax"], bb["ymax"]) for bb in boxes
  79. ]
  80. boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
  81. target.append(boxes)
  82. elif k == "boxes_category":
  83. anno = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
  84. boxes_category = [obj["name"] for obj in anno["annotation"]["object"]]
  85. boxes_category = [self.class_names.index(bc) for bc in boxes_category]
  86. boxes_category = np.array(boxes_category, dtype=np.int32)
  87. target.append(boxes_category)
  88. elif k == "mask":
  89. if "aug" in self.image_set:
  90. mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
  91. else:
  92. mask = cv2.imread(self.masks[index], cv2.IMREAD_COLOR)
  93. mask = self._trans_mask(mask)
  94. mask = mask[:, :, np.newaxis]
  95. target.append(mask)
  96. elif k == "info":
  97. if image is None:
  98. image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
  99. info = [image.shape[0], image.shape[1], self.file_names[index]]
  100. target.append(info)
  101. else:
  102. raise NotImplementedError
  103. return tuple(target)
  104. def __len__(self):
  105. return len(self.images)
  106. def _trans_mask(self, mask):
  107. label = np.ones(mask.shape[:2]) * 255
  108. for i in range(len(self.class_colors)):
  109. b, g, r = self.class_colors[i]
  110. label[
  111. (mask[:, :, 0] == b) & (mask[:, :, 1] == g) & (mask[:, :, 2] == r)
  112. ] = i
  113. return label.astype(np.uint8)
  114. def parse_voc_xml(self, node):
  115. voc_dict = {}
  116. children = list(node)
  117. if children:
  118. def_dic = collections.defaultdict(list)
  119. for dc in map(self.parse_voc_xml, children):
  120. for ind, v in dc.items():
  121. def_dic[ind].append(v)
  122. if node.tag == "annotation":
  123. def_dic["object"] = [def_dic["object"]]
  124. voc_dict = {
  125. node.tag: {
  126. ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()
  127. }
  128. }
  129. if node.text:
  130. text = node.text.strip()
  131. if not children:
  132. voc_dict[node.tag] = text
  133. return voc_dict
  134. class_names = (
  135. "aeroplane",
  136. "bicycle",
  137. "bird",
  138. "boat",
  139. "bottle",
  140. "bus",
  141. "car",
  142. "cat",
  143. "chair",
  144. "cow",
  145. "diningtable",
  146. "dog",
  147. "horse",
  148. "motorbike",
  149. "person",
  150. "pottedplant",
  151. "sheep",
  152. "sofa",
  153. "train",
  154. "tvmonitor",
  155. )
  156. class_colors = [
  157. [0, 0, 128],
  158. [0, 128, 0],
  159. [0, 128, 128],
  160. [128, 0, 0],
  161. [128, 0, 128],
  162. [128, 128, 0],
  163. [128, 128, 128],
  164. [0, 0, 64],
  165. [0, 0, 192],
  166. [0, 128, 64],
  167. [0, 128, 192],
  168. [128, 0, 64],
  169. [128, 0, 192],
  170. [128, 128, 64],
  171. [128, 128, 192],
  172. [0, 64, 0],
  173. [0, 64, 128],
  174. [0, 192, 0],
  175. [0, 192, 128],
  176. [128, 64, 0],
  177. ]

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。