You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

coco.py 9.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # ---------------------------------------------------------------------
  10. # Part of the following code in this file refs to maskrcnn-benchmark
  11. # MIT License
  12. #
  13. # Copyright (c) 2018 Facebook
  14. # ---------------------------------------------------------------------
  15. from collections import OrderedDict, defaultdict
  16. import json
  17. import os
  18. import cv2
  19. import numpy as np
  20. from .meta_vision import VisionDataset
  21. min_keypoints_per_image = 10
  22. def _count_visible_keypoints(anno):
  23. return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
  24. def _has_only_empty_bbox(anno):
  25. return all(any(o <= 0 for o in obj["bbox"][2:]) for obj in anno)
  26. def has_valid_annotation(anno):
  27. # if it"s empty, there is no annotation
  28. if len(anno) == 0:
  29. return False
  30. # if all boxes have close to zero area, there is no annotation
  31. if _has_only_empty_bbox(anno):
  32. return False
  33. # keypoints task have a slight different critera for considering
  34. # if an annotation is valid
  35. if "keypoints" not in anno[0]:
  36. return True
  37. # for keypoint detection tasks, only consider valid images those
  38. # containing at least min_keypoints_per_image
  39. if _count_visible_keypoints(anno) >= min_keypoints_per_image:
  40. return True
  41. return False
  42. class COCO(VisionDataset):
  43. r"""`MS COCO <http://cocodataset.org/#home>`_ Dataset.
  44. """
  45. supported_order = (
  46. "image",
  47. "boxes",
  48. "boxes_category",
  49. # TODO: need to check
  50. # "keypoints",
  51. # "polygons",
  52. "info",
  53. )
  54. def __init__(
  55. self, root, ann_file, remove_images_without_annotations=False, *, order=None
  56. ):
  57. super().__init__(root, order=order, supported_order=self.supported_order)
  58. with open(ann_file, "r") as f:
  59. dataset = json.load(f)
  60. self.imgs = OrderedDict()
  61. for img in dataset["images"]:
  62. # for saving memory
  63. if "license" in img:
  64. del img["license"]
  65. if "coco_url" in img:
  66. del img["coco_url"]
  67. if "date_captured" in img:
  68. del img["date_captured"]
  69. if "flickr_url" in img:
  70. del img["flickr_url"]
  71. self.imgs[img["id"]] = img
  72. self.img_to_anns = defaultdict(list)
  73. for ann in dataset["annotations"]:
  74. # for saving memory
  75. if (
  76. "boxes" not in self.order
  77. and "boxes_category" not in self.order
  78. and "bbox" in ann
  79. ):
  80. del ann["bbox"]
  81. if "polygons" not in self.order and "segmentation" in ann:
  82. del ann["segmentation"]
  83. self.img_to_anns[ann["image_id"]].append(ann)
  84. self.cats = OrderedDict()
  85. for cat in dataset["categories"]:
  86. self.cats[cat["id"]] = cat
  87. self.ids = list(sorted(self.imgs.keys()))
  88. # filter images without detection annotations
  89. if remove_images_without_annotations:
  90. ids = []
  91. for img_id in self.ids:
  92. anno = self.img_to_anns[img_id]
  93. if has_valid_annotation(anno):
  94. ids.append(img_id)
  95. self.ids = ids
  96. self.json_category_id_to_contiguous_id = {
  97. v: i + 1 for i, v in enumerate(self.cats.keys())
  98. }
  99. self.contiguous_category_id_to_json_id = {
  100. v: k for k, v in self.json_category_id_to_contiguous_id.items()
  101. }
  102. def __getitem__(self, index):
  103. img_id = self.ids[index]
  104. anno = self.img_to_anns[img_id]
  105. # filter crowd annotations
  106. anno = [obj for obj in anno if obj["iscrowd"] == 0]
  107. # filter empty annotations
  108. anno = [obj for obj in anno if obj["area"] > 0]
  109. target = []
  110. for k in self.order:
  111. if k == "image":
  112. file_name = self.imgs[img_id]["file_name"]
  113. path = os.path.join(self.root, file_name)
  114. image = cv2.imread(path, cv2.IMREAD_COLOR)
  115. target.append(image)
  116. elif k == "boxes":
  117. boxes = [obj["bbox"] for obj in anno]
  118. boxes = np.array(boxes).reshape(-1, 4)
  119. # transfer boxes from xywh to xyxy
  120. boxes[:, 2:] += boxes[:, :2]
  121. target.append(boxes)
  122. elif k == "boxes_category":
  123. boxes_category = [obj["category_id"] for obj in anno]
  124. boxes_category = [
  125. self.json_category_id_to_contiguous_id[c] for c in boxes_category
  126. ]
  127. boxes_category = np.array(boxes_category)
  128. target.append(boxes_category)
  129. # TODO: need to check
  130. # elif k == "keypoints":
  131. # keypoints = [obj["keypoints"] for obj in anno]
  132. # keypoints = np.array(keypoints).reshape(-1, len(self.keypoint_names), 3)
  133. # target.append(keypoints)
  134. # elif k == "polygons":
  135. # polygons = [obj["segmentation"] for obj in anno]
  136. # polygons = [[np.array(p).reshape(-1, 2) for p in ps] for ps in polygons]
  137. # target.append(polygons)
  138. elif k == "info":
  139. info = self.imgs[img_id]
  140. info = [info["height"], info["width"], info["file_name"]]
  141. target.append(info)
  142. else:
  143. raise NotImplementedError
  144. return tuple(target)
  145. def __len__(self):
  146. return len(self.ids)
  147. def get_img_info(self, index):
  148. img_id = self.ids[index]
  149. img_info = self.imgs[img_id]
  150. return img_info
  151. class_names = (
  152. "background",
  153. "person",
  154. "bicycle",
  155. "car",
  156. "motorcycle",
  157. "airplane",
  158. "bus",
  159. "train",
  160. "truck",
  161. "boat",
  162. "traffic light",
  163. "fire hydrant",
  164. "stop sign",
  165. "parking meter",
  166. "bench",
  167. "bird",
  168. "cat",
  169. "dog",
  170. "horse",
  171. "sheep",
  172. "cow",
  173. "elephant",
  174. "bear",
  175. "zebra",
  176. "giraffe",
  177. "backpack",
  178. "umbrella",
  179. "handbag",
  180. "tie",
  181. "suitcase",
  182. "frisbee",
  183. "skis",
  184. "snowboard",
  185. "sports ball",
  186. "kite",
  187. "baseball bat",
  188. "baseball glove",
  189. "skateboard",
  190. "surfboard",
  191. "tennis racket",
  192. "bottle",
  193. "wine glass",
  194. "cup",
  195. "fork",
  196. "knife",
  197. "spoon",
  198. "bowl",
  199. "banana",
  200. "apple",
  201. "sandwich",
  202. "orange",
  203. "broccoli",
  204. "carrot",
  205. "hot dog",
  206. "pizza",
  207. "donut",
  208. "cake",
  209. "chair",
  210. "couch",
  211. "potted plant",
  212. "bed",
  213. "dining table",
  214. "toilet",
  215. "tv",
  216. "laptop",
  217. "mouse",
  218. "remote",
  219. "keyboard",
  220. "cell phone",
  221. "microwave",
  222. "oven",
  223. "toaster",
  224. "sink",
  225. "refrigerator",
  226. "book",
  227. "clock",
  228. "vase",
  229. "scissors",
  230. "teddy bear",
  231. "hair drier",
  232. "toothbrush",
  233. )
  234. classes_originID = {
  235. "person": 1,
  236. "bicycle": 2,
  237. "car": 3,
  238. "motorcycle": 4,
  239. "airplane": 5,
  240. "bus": 6,
  241. "train": 7,
  242. "truck": 8,
  243. "boat": 9,
  244. "traffic light": 10,
  245. "fire hydrant": 11,
  246. "stop sign": 13,
  247. "parking meter": 14,
  248. "bench": 15,
  249. "bird": 16,
  250. "cat": 17,
  251. "dog": 18,
  252. "horse": 19,
  253. "sheep": 20,
  254. "cow": 21,
  255. "elephant": 22,
  256. "bear": 23,
  257. "zebra": 24,
  258. "giraffe": 25,
  259. "backpack": 27,
  260. "umbrella": 28,
  261. "handbag": 31,
  262. "tie": 32,
  263. "suitcase": 33,
  264. "frisbee": 34,
  265. "skis": 35,
  266. "snowboard": 36,
  267. "sports ball": 37,
  268. "kite": 38,
  269. "baseball bat": 39,
  270. "baseball glove": 40,
  271. "skateboard": 41,
  272. "surfboard": 42,
  273. "tennis racket": 43,
  274. "bottle": 44,
  275. "wine glass": 46,
  276. "cup": 47,
  277. "fork": 48,
  278. "knife": 49,
  279. "spoon": 50,
  280. "bowl": 51,
  281. "banana": 52,
  282. "apple": 53,
  283. "sandwich": 54,
  284. "orange": 55,
  285. "broccoli": 56,
  286. "carrot": 57,
  287. "hot dog": 58,
  288. "pizza": 59,
  289. "donut": 60,
  290. "cake": 61,
  291. "chair": 62,
  292. "couch": 63,
  293. "potted plant": 64,
  294. "bed": 65,
  295. "dining table": 67,
  296. "toilet": 70,
  297. "tv": 72,
  298. "laptop": 73,
  299. "mouse": 74,
  300. "remote": 75,
  301. "keyboard": 76,
  302. "cell phone": 77,
  303. "microwave": 78,
  304. "oven": 79,
  305. "toaster": 80,
  306. "sink": 81,
  307. "refrigerator": 82,
  308. "book": 84,
  309. "clock": 85,
  310. "vase": 86,
  311. "scissors": 87,
  312. "teddy bear": 88,
  313. "hair drier": 89,
  314. "toothbrush": 90,
  315. }
  316. keypoint_names = (
  317. "nose",
  318. "left_eye",
  319. "right_eye",
  320. "left_ear",
  321. "right_ear",
  322. "left_shoulder",
  323. "right_shoulder",
  324. "left_elbow",
  325. "right_elbow",
  326. "left_wrist",
  327. "right_wrist",
  328. "left_hip",
  329. "right_hip",
  330. "left_knee",
  331. "right_knee",
  332. "left_ankle",
  333. "right_ankle",
  334. )

MegEngine 安装包中已集成使用 GPU 运行代码所需的 CUDA 环境，因此无需区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发，欢迎访问 MegStudio 平台。

Contributors (1)