You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

coco.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # ---------------------------------------------------------------------
  10. # Part of the following code in this file refs to maskrcnn-benchmark
  11. # MIT License
  12. #
  13. # Copyright (c) 2018 Facebook
  14. # ---------------------------------------------------------------------
  15. import json
  16. import os
  17. from collections import defaultdict
  18. import cv2
  19. import numpy as np
  20. from .meta_vision import VisionDataset
# Minimum number of visible keypoints, summed over all annotations of one
# image, that has_valid_annotation requires before it keeps the image when
# "keypoints" is one of the requested fields.
min_keypoints_per_image = 10
  22. def _count_visible_keypoints(anno):
  23. return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
  24. def has_valid_annotation(anno, order):
  25. # if it"s empty, there is no annotation
  26. if len(anno) == 0:
  27. return False
  28. if "boxes" in order or "boxes_category" in order:
  29. if "bbox" not in anno[0]:
  30. return False
  31. if "keypoints" in order:
  32. if "keypoints" not in anno[0]:
  33. return False
  34. # for keypoint detection tasks, only consider valid images those
  35. # containing at least min_keypoints_per_image
  36. if _count_visible_keypoints(anno) < min_keypoints_per_image:
  37. return False
  38. return True
  39. class COCO(VisionDataset):
  40. r"""`MS COCO <http://cocodataset.org/#home>`_ Dataset.
  41. """
  42. supported_order = (
  43. "image",
  44. "boxes",
  45. "boxes_category",
  46. "keypoints",
  47. # TODO: need to check
  48. # "polygons",
  49. "info",
  50. )
  51. def __init__(
  52. self, root, ann_file, remove_images_without_annotations=False, *, order=None
  53. ):
  54. super().__init__(root, order=order, supported_order=self.supported_order)
  55. with open(ann_file, "r") as f:
  56. dataset = json.load(f)
  57. self.imgs = dict()
  58. for img in dataset["images"]:
  59. # for saving memory
  60. if "license" in img:
  61. del img["license"]
  62. if "coco_url" in img:
  63. del img["coco_url"]
  64. if "date_captured" in img:
  65. del img["date_captured"]
  66. if "flickr_url" in img:
  67. del img["flickr_url"]
  68. self.imgs[img["id"]] = img
  69. self.img_to_anns = defaultdict(list)
  70. for ann in dataset["annotations"]:
  71. # for saving memory
  72. if (
  73. "boxes" not in self.order
  74. and "boxes_category" not in self.order
  75. and "bbox" in ann
  76. ):
  77. del ann["bbox"]
  78. if "polygons" not in self.order and "segmentation" in ann:
  79. del ann["segmentation"]
  80. self.img_to_anns[ann["image_id"]].append(ann)
  81. self.cats = dict()
  82. for cat in dataset["categories"]:
  83. self.cats[cat["id"]] = cat
  84. self.ids = list(sorted(self.imgs.keys()))
  85. # filter images without detection annotations
  86. if remove_images_without_annotations:
  87. ids = []
  88. for img_id in self.ids:
  89. anno = self.img_to_anns[img_id]
  90. # filter crowd annotations
  91. anno = [obj for obj in anno if obj["iscrowd"] == 0]
  92. anno = [
  93. obj for obj in anno if obj["bbox"][2] > 0 and obj["bbox"][3] > 0
  94. ]
  95. if has_valid_annotation(anno, order):
  96. ids.append(img_id)
  97. self.img_to_anns[img_id] = anno
  98. else:
  99. del self.imgs[img_id]
  100. del self.img_to_anns[img_id]
  101. self.ids = ids
  102. self.json_category_id_to_contiguous_id = {
  103. v: i + 1 for i, v in enumerate(sorted(self.cats.keys()))
  104. }
  105. self.contiguous_category_id_to_json_id = {
  106. v: k for k, v in self.json_category_id_to_contiguous_id.items()
  107. }
  108. def __getitem__(self, index):
  109. img_id = self.ids[index]
  110. anno = self.img_to_anns[img_id]
  111. target = []
  112. for k in self.order:
  113. if k == "image":
  114. file_name = self.imgs[img_id]["file_name"]
  115. path = os.path.join(self.root, file_name)
  116. image = cv2.imread(path, cv2.IMREAD_COLOR)
  117. target.append(image)
  118. elif k == "boxes":
  119. boxes = [obj["bbox"] for obj in anno]
  120. boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
  121. # transfer boxes from xywh to xyxy
  122. boxes[:, 2:] += boxes[:, :2]
  123. target.append(boxes)
  124. elif k == "boxes_category":
  125. boxes_category = [obj["category_id"] for obj in anno]
  126. boxes_category = [
  127. self.json_category_id_to_contiguous_id[c] for c in boxes_category
  128. ]
  129. boxes_category = np.array(boxes_category, dtype=np.int32)
  130. target.append(boxes_category)
  131. elif k == "keypoints":
  132. keypoints = [obj["keypoints"] for obj in anno]
  133. keypoints = np.array(keypoints, dtype=np.float32).reshape(
  134. -1, len(self.keypoint_names), 3
  135. )
  136. target.append(keypoints)
  137. elif k == "polygons":
  138. polygons = [obj["segmentation"] for obj in anno]
  139. polygons = [
  140. [np.array(p, dtype=np.float32).reshape(-1, 2) for p in ps]
  141. for ps in polygons
  142. ]
  143. target.append(polygons)
  144. elif k == "info":
  145. info = self.imgs[img_id]
  146. info = [info["height"], info["width"], info["file_name"]]
  147. target.append(info)
  148. else:
  149. raise NotImplementedError
  150. return tuple(target)
  151. def __len__(self):
  152. return len(self.ids)
  153. def get_img_info(self, index):
  154. img_id = self.ids[index]
  155. img_info = self.imgs[img_id]
  156. return img_info
  157. class_names = (
  158. "person",
  159. "bicycle",
  160. "car",
  161. "motorcycle",
  162. "airplane",
  163. "bus",
  164. "train",
  165. "truck",
  166. "boat",
  167. "traffic light",
  168. "fire hydrant",
  169. "stop sign",
  170. "parking meter",
  171. "bench",
  172. "bird",
  173. "cat",
  174. "dog",
  175. "horse",
  176. "sheep",
  177. "cow",
  178. "elephant",
  179. "bear",
  180. "zebra",
  181. "giraffe",
  182. "backpack",
  183. "umbrella",
  184. "handbag",
  185. "tie",
  186. "suitcase",
  187. "frisbee",
  188. "skis",
  189. "snowboard",
  190. "sports ball",
  191. "kite",
  192. "baseball bat",
  193. "baseball glove",
  194. "skateboard",
  195. "surfboard",
  196. "tennis racket",
  197. "bottle",
  198. "wine glass",
  199. "cup",
  200. "fork",
  201. "knife",
  202. "spoon",
  203. "bowl",
  204. "banana",
  205. "apple",
  206. "sandwich",
  207. "orange",
  208. "broccoli",
  209. "carrot",
  210. "hot dog",
  211. "pizza",
  212. "donut",
  213. "cake",
  214. "chair",
  215. "couch",
  216. "potted plant",
  217. "bed",
  218. "dining table",
  219. "toilet",
  220. "tv",
  221. "laptop",
  222. "mouse",
  223. "remote",
  224. "keyboard",
  225. "cell phone",
  226. "microwave",
  227. "oven",
  228. "toaster",
  229. "sink",
  230. "refrigerator",
  231. "book",
  232. "clock",
  233. "vase",
  234. "scissors",
  235. "teddy bear",
  236. "hair drier",
  237. "toothbrush",
  238. )
  239. classes_originID = {
  240. "person": 1,
  241. "bicycle": 2,
  242. "car": 3,
  243. "motorcycle": 4,
  244. "airplane": 5,
  245. "bus": 6,
  246. "train": 7,
  247. "truck": 8,
  248. "boat": 9,
  249. "traffic light": 10,
  250. "fire hydrant": 11,
  251. "stop sign": 13,
  252. "parking meter": 14,
  253. "bench": 15,
  254. "bird": 16,
  255. "cat": 17,
  256. "dog": 18,
  257. "horse": 19,
  258. "sheep": 20,
  259. "cow": 21,
  260. "elephant": 22,
  261. "bear": 23,
  262. "zebra": 24,
  263. "giraffe": 25,
  264. "backpack": 27,
  265. "umbrella": 28,
  266. "handbag": 31,
  267. "tie": 32,
  268. "suitcase": 33,
  269. "frisbee": 34,
  270. "skis": 35,
  271. "snowboard": 36,
  272. "sports ball": 37,
  273. "kite": 38,
  274. "baseball bat": 39,
  275. "baseball glove": 40,
  276. "skateboard": 41,
  277. "surfboard": 42,
  278. "tennis racket": 43,
  279. "bottle": 44,
  280. "wine glass": 46,
  281. "cup": 47,
  282. "fork": 48,
  283. "knife": 49,
  284. "spoon": 50,
  285. "bowl": 51,
  286. "banana": 52,
  287. "apple": 53,
  288. "sandwich": 54,
  289. "orange": 55,
  290. "broccoli": 56,
  291. "carrot": 57,
  292. "hot dog": 58,
  293. "pizza": 59,
  294. "donut": 60,
  295. "cake": 61,
  296. "chair": 62,
  297. "couch": 63,
  298. "potted plant": 64,
  299. "bed": 65,
  300. "dining table": 67,
  301. "toilet": 70,
  302. "tv": 72,
  303. "laptop": 73,
  304. "mouse": 74,
  305. "remote": 75,
  306. "keyboard": 76,
  307. "cell phone": 77,
  308. "microwave": 78,
  309. "oven": 79,
  310. "toaster": 80,
  311. "sink": 81,
  312. "refrigerator": 82,
  313. "book": 84,
  314. "clock": 85,
  315. "vase": 86,
  316. "scissors": 87,
  317. "teddy bear": 88,
  318. "hair drier": 89,
  319. "toothbrush": 90,
  320. }
  321. keypoint_names = (
  322. "nose",
  323. "left_eye",
  324. "right_eye",
  325. "left_ear",
  326. "right_ear",
  327. "left_shoulder",
  328. "right_shoulder",
  329. "left_elbow",
  330. "right_elbow",
  331. "left_wrist",
  332. "right_wrist",
  333. "left_hip",
  334. "right_hip",
  335. "left_knee",
  336. "right_knee",
  337. "left_ankle",
  338. "right_ankle",
  339. )

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。