You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

coco.py 9.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  1. # -*- coding: utf-8 -*-
  2. # ---------------------------------------------------------------------
  3. # Part of the following code in this file refs to maskrcnn-benchmark
  4. # MIT License
  5. #
  6. # Copyright (c) 2018 Facebook
  7. # ---------------------------------------------------------------------
  8. import json
  9. import os
  10. from collections import defaultdict
  11. import cv2
  12. import numpy as np
  13. from .meta_vision import VisionDataset
  14. min_keypoints_per_image = 10
  15. def _count_visible_keypoints(anno):
  16. return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
  17. def has_valid_annotation(anno, order):
  18. # if it"s empty, there is no annotation
  19. if len(anno) == 0:
  20. return False
  21. if "boxes" in order or "boxes_category" in order:
  22. if "bbox" not in anno[0]:
  23. return False
  24. if "keypoints" in order:
  25. if "keypoints" not in anno[0]:
  26. return False
  27. # for keypoint detection tasks, only consider valid images those
  28. # containing at least min_keypoints_per_image
  29. if _count_visible_keypoints(anno) < min_keypoints_per_image:
  30. return False
  31. return True
  32. class COCO(VisionDataset):
  33. r"""`MS COCO <http://cocodataset.org/#home>`_ Dataset."""
  34. supported_order = (
  35. "image",
  36. "boxes",
  37. "boxes_category",
  38. "keypoints",
  39. # TODO: need to check
  40. # "polygons",
  41. "info",
  42. )
  43. def __init__(
  44. self, root, ann_file, remove_images_without_annotations=False, *, order=None
  45. ):
  46. super().__init__(root, order=order, supported_order=self.supported_order)
  47. with open(ann_file, "r") as f:
  48. dataset = json.load(f)
  49. self.imgs = dict()
  50. for img in dataset["images"]:
  51. # for saving memory
  52. if "license" in img:
  53. del img["license"]
  54. if "coco_url" in img:
  55. del img["coco_url"]
  56. if "date_captured" in img:
  57. del img["date_captured"]
  58. if "flickr_url" in img:
  59. del img["flickr_url"]
  60. self.imgs[img["id"]] = img
  61. self.img_to_anns = defaultdict(list)
  62. for ann in dataset["annotations"]:
  63. # for saving memory
  64. if (
  65. "boxes" not in self.order
  66. and "boxes_category" not in self.order
  67. and "bbox" in ann
  68. ):
  69. del ann["bbox"]
  70. if "polygons" not in self.order and "segmentation" in ann:
  71. del ann["segmentation"]
  72. self.img_to_anns[ann["image_id"]].append(ann)
  73. self.cats = dict()
  74. for cat in dataset["categories"]:
  75. self.cats[cat["id"]] = cat
  76. self.ids = list(sorted(self.imgs.keys()))
  77. # filter images without detection annotations
  78. if remove_images_without_annotations:
  79. ids = []
  80. for img_id in self.ids:
  81. anno = self.img_to_anns[img_id]
  82. # filter crowd annotations
  83. anno = [obj for obj in anno if obj["iscrowd"] == 0]
  84. anno = [
  85. obj for obj in anno if obj["bbox"][2] > 0 and obj["bbox"][3] > 0
  86. ]
  87. if has_valid_annotation(anno, order):
  88. ids.append(img_id)
  89. self.img_to_anns[img_id] = anno
  90. else:
  91. del self.imgs[img_id]
  92. del self.img_to_anns[img_id]
  93. self.ids = ids
  94. self.json_category_id_to_contiguous_id = {
  95. v: i + 1 for i, v in enumerate(sorted(self.cats.keys()))
  96. }
  97. self.contiguous_category_id_to_json_id = {
  98. v: k for k, v in self.json_category_id_to_contiguous_id.items()
  99. }
  100. def __getitem__(self, index):
  101. img_id = self.ids[index]
  102. anno = self.img_to_anns[img_id]
  103. target = []
  104. for k in self.order:
  105. if k == "image":
  106. file_name = self.imgs[img_id]["file_name"]
  107. path = os.path.join(self.root, file_name)
  108. image = cv2.imread(path, cv2.IMREAD_COLOR)
  109. target.append(image)
  110. elif k == "boxes":
  111. boxes = [obj["bbox"] for obj in anno]
  112. boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
  113. # transfer boxes from xywh to xyxy
  114. boxes[:, 2:] += boxes[:, :2]
  115. target.append(boxes)
  116. elif k == "boxes_category":
  117. boxes_category = [obj["category_id"] for obj in anno]
  118. boxes_category = [
  119. self.json_category_id_to_contiguous_id[c] for c in boxes_category
  120. ]
  121. boxes_category = np.array(boxes_category, dtype=np.int32)
  122. target.append(boxes_category)
  123. elif k == "keypoints":
  124. keypoints = [obj["keypoints"] for obj in anno]
  125. keypoints = np.array(keypoints, dtype=np.float32).reshape(
  126. -1, len(self.keypoint_names), 3
  127. )
  128. target.append(keypoints)
  129. elif k == "polygons":
  130. polygons = [obj["segmentation"] for obj in anno]
  131. polygons = [
  132. [np.array(p, dtype=np.float32).reshape(-1, 2) for p in ps]
  133. for ps in polygons
  134. ]
  135. target.append(polygons)
  136. elif k == "info":
  137. info = self.imgs[img_id]
  138. info = [info["height"], info["width"], info["file_name"]]
  139. target.append(info)
  140. else:
  141. raise NotImplementedError
  142. return tuple(target)
  143. def __len__(self):
  144. return len(self.ids)
  145. def get_img_info(self, index):
  146. img_id = self.ids[index]
  147. img_info = self.imgs[img_id]
  148. return img_info
  149. class_names = (
  150. "person",
  151. "bicycle",
  152. "car",
  153. "motorcycle",
  154. "airplane",
  155. "bus",
  156. "train",
  157. "truck",
  158. "boat",
  159. "traffic light",
  160. "fire hydrant",
  161. "stop sign",
  162. "parking meter",
  163. "bench",
  164. "bird",
  165. "cat",
  166. "dog",
  167. "horse",
  168. "sheep",
  169. "cow",
  170. "elephant",
  171. "bear",
  172. "zebra",
  173. "giraffe",
  174. "backpack",
  175. "umbrella",
  176. "handbag",
  177. "tie",
  178. "suitcase",
  179. "frisbee",
  180. "skis",
  181. "snowboard",
  182. "sports ball",
  183. "kite",
  184. "baseball bat",
  185. "baseball glove",
  186. "skateboard",
  187. "surfboard",
  188. "tennis racket",
  189. "bottle",
  190. "wine glass",
  191. "cup",
  192. "fork",
  193. "knife",
  194. "spoon",
  195. "bowl",
  196. "banana",
  197. "apple",
  198. "sandwich",
  199. "orange",
  200. "broccoli",
  201. "carrot",
  202. "hot dog",
  203. "pizza",
  204. "donut",
  205. "cake",
  206. "chair",
  207. "couch",
  208. "potted plant",
  209. "bed",
  210. "dining table",
  211. "toilet",
  212. "tv",
  213. "laptop",
  214. "mouse",
  215. "remote",
  216. "keyboard",
  217. "cell phone",
  218. "microwave",
  219. "oven",
  220. "toaster",
  221. "sink",
  222. "refrigerator",
  223. "book",
  224. "clock",
  225. "vase",
  226. "scissors",
  227. "teddy bear",
  228. "hair drier",
  229. "toothbrush",
  230. )
  231. classes_originID = {
  232. "person": 1,
  233. "bicycle": 2,
  234. "car": 3,
  235. "motorcycle": 4,
  236. "airplane": 5,
  237. "bus": 6,
  238. "train": 7,
  239. "truck": 8,
  240. "boat": 9,
  241. "traffic light": 10,
  242. "fire hydrant": 11,
  243. "stop sign": 13,
  244. "parking meter": 14,
  245. "bench": 15,
  246. "bird": 16,
  247. "cat": 17,
  248. "dog": 18,
  249. "horse": 19,
  250. "sheep": 20,
  251. "cow": 21,
  252. "elephant": 22,
  253. "bear": 23,
  254. "zebra": 24,
  255. "giraffe": 25,
  256. "backpack": 27,
  257. "umbrella": 28,
  258. "handbag": 31,
  259. "tie": 32,
  260. "suitcase": 33,
  261. "frisbee": 34,
  262. "skis": 35,
  263. "snowboard": 36,
  264. "sports ball": 37,
  265. "kite": 38,
  266. "baseball bat": 39,
  267. "baseball glove": 40,
  268. "skateboard": 41,
  269. "surfboard": 42,
  270. "tennis racket": 43,
  271. "bottle": 44,
  272. "wine glass": 46,
  273. "cup": 47,
  274. "fork": 48,
  275. "knife": 49,
  276. "spoon": 50,
  277. "bowl": 51,
  278. "banana": 52,
  279. "apple": 53,
  280. "sandwich": 54,
  281. "orange": 55,
  282. "broccoli": 56,
  283. "carrot": 57,
  284. "hot dog": 58,
  285. "pizza": 59,
  286. "donut": 60,
  287. "cake": 61,
  288. "chair": 62,
  289. "couch": 63,
  290. "potted plant": 64,
  291. "bed": 65,
  292. "dining table": 67,
  293. "toilet": 70,
  294. "tv": 72,
  295. "laptop": 73,
  296. "mouse": 74,
  297. "remote": 75,
  298. "keyboard": 76,
  299. "cell phone": 77,
  300. "microwave": 78,
  301. "oven": 79,
  302. "toaster": 80,
  303. "sink": 81,
  304. "refrigerator": 82,
  305. "book": 84,
  306. "clock": 85,
  307. "vase": 86,
  308. "scissors": 87,
  309. "teddy bear": 88,
  310. "hair drier": 89,
  311. "toothbrush": 90,
  312. }
  313. keypoint_names = (
  314. "nose",
  315. "left_eye",
  316. "right_eye",
  317. "left_ear",
  318. "right_ear",
  319. "left_shoulder",
  320. "right_shoulder",
  321. "left_elbow",
  322. "right_elbow",
  323. "left_wrist",
  324. "right_wrist",
  325. "left_hip",
  326. "right_hip",
  327. "left_knee",
  328. "right_knee",
  329. "left_ankle",
  330. "right_ankle",
  331. )