You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

voc.py 7.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # -*- coding: utf-8 -*-
  2. # ---------------------------------------------------------------------
  3. # Part of the following code in this file refs to torchvision
  4. # BSD 3-Clause License
  5. #
  6. # Copyright (c) Soumith Chintala 2016,
  7. # All rights reserved.
  8. # ---------------------------------------------------------------------
  9. import collections.abc
  10. import os
  11. import xml.etree.ElementTree as ET
  12. import cv2
  13. import numpy as np
  14. from .meta_vision import VisionDataset
  15. class PascalVOC(VisionDataset):
  16. r"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset."""
  17. supported_order = (
  18. "image",
  19. "boxes",
  20. "boxes_category",
  21. "mask",
  22. "info",
  23. )
  24. def __init__(self, root, image_set, *, order=None):
  25. if ("boxes" in order or "boxes_category" in order) and "mask" in order:
  26. raise ValueError(
  27. "PascalVOC only supports boxes & boxes_category or mask, not both."
  28. )
  29. super().__init__(root, order=order, supported_order=self.supported_order)
  30. if not os.path.isdir(self.root):
  31. raise RuntimeError("Dataset not found or corrupted.")
  32. self.image_set = image_set
  33. image_dir = os.path.join(self.root, "JPEGImages")
  34. if "boxes" in order or "boxes_category" in order:
  35. annotation_dir = os.path.join(self.root, "Annotations")
  36. splitdet_dir = os.path.join(self.root, "ImageSets/Main")
  37. split_f = os.path.join(splitdet_dir, image_set.rstrip("\n") + ".txt")
  38. with open(os.path.join(split_f), "r") as f:
  39. self.file_names = [x.strip() for x in f.readlines()]
  40. self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names]
  41. self.annotations = [
  42. os.path.join(annotation_dir, x + ".xml") for x in self.file_names
  43. ]
  44. assert len(self.images) == len(self.annotations)
  45. elif "mask" in order:
  46. if "aug" in image_set:
  47. mask_dir = os.path.join(self.root, "SegmentationClass_aug")
  48. else:
  49. mask_dir = os.path.join(self.root, "SegmentationClass")
  50. splitmask_dir = os.path.join(self.root, "ImageSets/Segmentation")
  51. split_f = os.path.join(splitmask_dir, image_set.rstrip("\n") + ".txt")
  52. with open(os.path.join(split_f), "r") as f:
  53. self.file_names = [x.strip() for x in f.readlines()]
  54. self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names]
  55. self.masks = [os.path.join(mask_dir, x + ".png") for x in self.file_names]
  56. assert len(self.images) == len(self.masks)
  57. else:
  58. raise NotImplementedError
  59. self.img_infos = dict()
  60. def __getitem__(self, index):
  61. target = []
  62. for k in self.order:
  63. if k == "image":
  64. image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
  65. target.append(image)
  66. elif k == "boxes":
  67. anno = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
  68. boxes = [obj["bndbox"] for obj in anno["annotation"]["object"]]
  69. # boxes type xyxy
  70. boxes = [
  71. (bb["xmin"], bb["ymin"], bb["xmax"], bb["ymax"]) for bb in boxes
  72. ]
  73. boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
  74. target.append(boxes)
  75. elif k == "boxes_category":
  76. anno = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
  77. boxes_category = [obj["name"] for obj in anno["annotation"]["object"]]
  78. boxes_category = [
  79. self.class_names.index(bc) + 1 for bc in boxes_category
  80. ]
  81. boxes_category = np.array(boxes_category, dtype=np.int32)
  82. target.append(boxes_category)
  83. elif k == "mask":
  84. if "aug" in self.image_set:
  85. mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
  86. else:
  87. mask = cv2.imread(self.masks[index], cv2.IMREAD_COLOR)
  88. mask = self._trans_mask(mask)
  89. mask = mask[:, :, np.newaxis]
  90. target.append(mask)
  91. elif k == "info":
  92. info = self.get_img_info(index, image)
  93. info = [info["height"], info["width"], info["file_name"]]
  94. target.append(info)
  95. else:
  96. raise NotImplementedError
  97. return tuple(target)
  98. def __len__(self):
  99. return len(self.images)
  100. def get_img_info(self, index, image=None):
  101. if index not in self.img_infos:
  102. if image is None:
  103. image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
  104. self.img_infos[index] = dict(
  105. height=image.shape[0],
  106. width=image.shape[1],
  107. file_name=self.file_names[index],
  108. )
  109. return self.img_infos[index]
  110. def _trans_mask(self, mask):
  111. label = np.ones(mask.shape[:2]) * 255
  112. for i in range(len(self.class_colors)):
  113. b, g, r = self.class_colors[i]
  114. label[
  115. (mask[:, :, 0] == b) & (mask[:, :, 1] == g) & (mask[:, :, 2] == r)
  116. ] = i
  117. return label.astype(np.uint8)
  118. def parse_voc_xml(self, node):
  119. voc_dict = {}
  120. children = list(node)
  121. if children:
  122. def_dic = collections.defaultdict(list)
  123. for dc in map(self.parse_voc_xml, children):
  124. for ind, v in dc.items():
  125. def_dic[ind].append(v)
  126. if node.tag == "annotation":
  127. def_dic["object"] = [def_dic["object"]]
  128. voc_dict = {
  129. node.tag: {
  130. ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()
  131. }
  132. }
  133. if node.text:
  134. text = node.text.strip()
  135. if not children:
  136. voc_dict[node.tag] = text
  137. return voc_dict
  138. class_names = (
  139. "aeroplane",
  140. "bicycle",
  141. "bird",
  142. "boat",
  143. "bottle",
  144. "bus",
  145. "car",
  146. "cat",
  147. "chair",
  148. "cow",
  149. "diningtable",
  150. "dog",
  151. "horse",
  152. "motorbike",
  153. "person",
  154. "pottedplant",
  155. "sheep",
  156. "sofa",
  157. "train",
  158. "tvmonitor",
  159. )
  160. class_colors = [
  161. [0, 0, 0], # background
  162. [0, 0, 128],
  163. [0, 128, 0],
  164. [0, 128, 128],
  165. [128, 0, 0],
  166. [128, 0, 128],
  167. [128, 128, 0],
  168. [128, 128, 128],
  169. [0, 0, 64],
  170. [0, 0, 192],
  171. [0, 128, 64],
  172. [0, 128, 192],
  173. [128, 0, 64],
  174. [128, 0, 192],
  175. [128, 128, 64],
  176. [128, 128, 192],
  177. [0, 64, 0],
  178. [0, 64, 128],
  179. [0, 192, 0],
  180. [0, 192, 128],
  181. [128, 64, 0],
  182. ]