
kitti_dataset.py
  1. """
  2. # -*- coding: utf-8 -*-
  3. -----------------------------------------------------------------------------------
  4. # Author: Nguyen Mau Dung
  5. # DoC: 2020.08.17
  6. # email: nguyenmaudung93.kstn@gmail.com
  7. -----------------------------------------------------------------------------------
  8. # Description: This script for the KITTI dataset
  9. """
import sys
import os
import math

import numpy as np
from torch.utils.data import Dataset
import cv2
import torch

src_dir = os.path.dirname(os.path.realpath(__file__))
# while not src_dir.endswith("sfa"):
#     src_dir = os.path.dirname(src_dir)
if src_dir not in sys.path:
    sys.path.append(src_dir)

from data_process.kitti_data_utils import gen_hm_radius, compute_radius, Calibration, get_filtered_lidar
from data_process.kitti_bev_utils import makeBEVMap, drawRotatedBox, get_corners
from data_process import transformation
import config.kitti_config as cnf
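
# Both dataset_dir (train/test) and val_data_url (val) are expected to contain
# KITTI-style sub-folders, as used in __init__ below:
#   velodyne/   *.bin lidar scans
#   calib/      *.txt calibration files
#   label_2/    *.txt annotation files
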
class KittiDataset(Dataset):
    def __init__(self, configs, mode='train', lidar_aug=None, hflip_prob=None, num_samples=None):
        self.dataset_dir = configs.dataset_dir
        self.input_size = configs.input_size
        self.hm_size = configs.hm_size
        self.num_classes = configs.num_classes
        self.max_objects = configs.max_objects

        assert mode in ['train', 'val', 'test'], 'Invalid mode: {}'.format(mode)
        self.mode = mode
        self.is_test = (self.mode == 'test')
        # sub_folder = 'testing' if self.is_test else 'training'

        self.lidar_aug = lidar_aug
        self.hflip_prob = hflip_prob

        if mode == 'val':
            self.val_data_url = configs.val_data_url
            self.lidar_dir = os.path.join(self.val_data_url, "velodyne")
            self.calib_dir = os.path.join(self.val_data_url, "calib")
            self.label_dir = os.path.join(self.val_data_url, "label_2")
            # self.image_dir = os.path.join(self.dataset_dir, sub_folder, "image_2")
        else:
            self.lidar_dir = os.path.join(self.dataset_dir, "velodyne")
            self.calib_dir = os.path.join(self.dataset_dir, "calib")
            self.label_dir = os.path.join(self.dataset_dir, "label_2")

        # Build the sample id list from the .bin files in the lidar directory
        # (sorted for a deterministic ordering across runs).
        # split_txt_path = os.path.join('../dataset/apollo/', 'ImageSets', '{}.txt'.format(mode))
        sample_list = []
        for bin_file in sorted(os.listdir(self.lidar_dir)):
            bin_name = os.path.splitext(bin_file)[0]
            sample_list.append(bin_name)
        self.sample_id_list = sample_list

        if num_samples is not None:
            self.sample_id_list = self.sample_id_list[:num_samples]
        self.num_samples = len(self.sample_id_list)
    def __len__(self):
        return len(self.sample_id_list)

    def __getitem__(self, index):
        if self.is_test:
            return self.load_img_only(index)
        else:
            return self.load_img_with_targets(index)
    def load_img_only(self, index):
        """Load only the BEV map (no targets) for the testing phase"""
        sample_id = self.sample_id_list[index]
        # img_path, img_rgb = self.get_image(sample_id)
        lidarData = self.get_lidar(sample_id)
        lidarData = get_filtered_lidar(lidarData, cnf.boundary)
        bev_map = makeBEVMap(lidarData, cnf.boundary)
        bev_map = torch.from_numpy(bev_map)
        bev_path = os.path.join(self.lidar_dir, '{}.png'.format(sample_id))
        metadatas = {
            'bev_path': bev_path,
        }

        # return metadatas, bev_map, img_rgb
        return bev_map, metadatas
    def load_img_with_targets(self, index):
        """Load the BEV map and targets for the training and validation phases"""
        sample_id = self.sample_id_list[index]
        # img_path = os.path.join(self.image_dir, '{}.png'.format(sample_id))
        lidarData = self.get_lidar(sample_id)
        # calib = self.get_calib(sample_id)
        labels, has_labels = self.get_label(sample_id)
        # if has_labels:
        #     labels[:, 1:] = transformation.camera_to_lidar_box(labels[:, 1:], calib.V2C, calib.R0, calib.P2)

        if self.lidar_aug:
            lidarData, labels[:, 1:] = self.lidar_aug(lidarData, labels[:, 1:])

        lidarData, labels = get_filtered_lidar(lidarData, cnf.boundary, labels)
        bev_map = makeBEVMap(lidarData, cnf.boundary)
        bev_map = torch.from_numpy(bev_map)

        hflipped = False
        # hflip_prob may be None (no flipping); guard before comparing
        if self.hflip_prob is not None and np.random.random() < self.hflip_prob:
            hflipped = True
            # C, H, W
            bev_map = torch.flip(bev_map, [-1])

        targets = self.build_targets(labels, hflipped)

        # metadatas = {
        #     'img_path': img_path,
        #     'hflipped': hflipped
        # }
        # return metadatas, bev_map, targets
        return bev_map, targets
    def get_image(self, idx):
        img_path = os.path.join(self.image_dir, '{}.png'.format(idx))
        img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
        return img_path, img

    def get_calib(self, idx):
        calib_file = os.path.join(self.calib_dir, '{}.txt'.format(idx))
        # assert os.path.isfile(calib_file)
        return Calibration(calib_file)

    def get_lidar(self, idx):
        lidar_file = os.path.join(self.lidar_dir, '{}.bin'.format(idx))
        # assert os.path.isfile(lidar_file)
        # Each KITTI velodyne .bin stores float32 points as (x, y, z, reflectance)
        return np.fromfile(lidar_file, dtype=np.float32).reshape(-1, 4)
    def get_label(self, idx):
        labels = []
        label_path = os.path.join(self.label_dir, '{}.txt'.format(idx))
        with open(label_path, 'r') as f:
            for line in f:
                line_parts = line.rstrip().split(' ')
                obj_name = line_parts[0]  # 'Car', 'Pedestrian', ...
                cat_id = int(cnf.CLASS_NAME_TO_ID[obj_name])
                if cat_id <= -99:  # ignore Tram and Misc
                    continue
                truncated = int(float(line_parts[1]))  # truncated pixel ratio [0..1]
                occluded = int(line_parts[2])  # 0=visible, 1=partly occluded, 2=fully occluded, 3=unknown
                alpha = float(line_parts[3])  # object observation angle [-pi..pi]
                # xmin, ymin, xmax, ymax
                # bbox = np.array([float(line_parts[4]), float(line_parts[5]), float(line_parts[6]), float(line_parts[7])])
                # height, width, length (h, w, l)
                h, w, l = float(line_parts[8]), float(line_parts[9]), float(line_parts[10])
                # location (x, y, z) in camera coord.
                x, y, z = float(line_parts[11]), float(line_parts[12]), float(line_parts[13])
                ry = float(line_parts[14])  # yaw angle (around Y-axis in camera coordinates) [-pi..pi]
                object_label = [cat_id, x, y, z, h, w, l, ry]
                labels.append(object_label)

        if len(labels) == 0:
            labels = np.zeros((1, 8), dtype=np.float32)
            has_labels = False
        else:
            labels = np.array(labels, dtype=np.float32)
            has_labels = True

        return labels, has_labels
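
    # build_targets encodes the labels CenterNet-style: a per-class Gaussian
    # heatmap of object centers on the BEV grid, plus per-object regression
    # targets (center offset, sin/cos of yaw, z coordinate, 3D dimensions).
    # The per-object targets are meant to be gathered at training time via the
    # flattened heatmap indices, row * hm_w + col.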
    def build_targets(self, labels, hflipped):
        minX = cnf.boundary['minX']
        maxX = cnf.boundary['maxX']
        minY = cnf.boundary['minY']
        maxY = cnf.boundary['maxY']
        minZ = cnf.boundary['minZ']
        maxZ = cnf.boundary['maxZ']

        num_objects = min(len(labels), self.max_objects)
        hm_l, hm_w = self.hm_size

        hm_main_center = np.zeros((self.num_classes, hm_l, hm_w), dtype=np.float32)
        cen_offset = np.zeros((self.max_objects, 2), dtype=np.float32)
        direction = np.zeros((self.max_objects, 2), dtype=np.float32)
        z_coor = np.zeros((self.max_objects, 1), dtype=np.float32)
        dimension = np.zeros((self.max_objects, 3), dtype=np.float32)

        indices_center = np.zeros((self.max_objects), dtype=np.int64)
        obj_mask = np.zeros((self.max_objects), dtype=np.uint8)

        for k in range(num_objects):
            cls_id, x, y, z, h, w, l, yaw = labels[k]
            cls_id = int(cls_id)
            # Invert yaw angle
            yaw = -yaw
            if not ((minX <= x <= maxX) and (minY <= y <= maxY) and (minZ <= z <= maxZ)):
                continue
            if (h <= 0) or (w <= 0) or (l <= 0):
                continue

            bbox_l = l / cnf.bound_size_x * hm_l
            bbox_w = w / cnf.bound_size_y * hm_w
            radius = compute_radius((math.ceil(bbox_l), math.ceil(bbox_w)))
            radius = max(0, int(radius))

            center_y = (x - minX) / cnf.bound_size_x * hm_l  # x --> y (invert to 2D image space)
            center_x = (y - minY) / cnf.bound_size_y * hm_w  # y --> x

            center = np.array([center_x, center_y], dtype=np.float32)

            if hflipped:
                center[0] = hm_w - center[0] - 1

            center_int = center.astype(np.int32)
            if cls_id < 0:
                ignore_ids = [_ for _ in range(self.num_classes)] if cls_id == -1 else [-cls_id - 2]
                # Consider to make mask ignore
                for cls_ig in ignore_ids:
                    gen_hm_radius(hm_main_center[cls_ig], center_int, radius)
                hm_main_center[ignore_ids, center_int[1], center_int[0]] = 0.9999
                continue

            # Generate heatmaps for main center
            gen_hm_radius(hm_main_center[cls_id], center, radius)
            # Index of the center
            indices_center[k] = center_int[1] * hm_w + center_int[0]

            # targets for center offset
            cen_offset[k] = center - center_int

            # targets for dimension
            dimension[k, 0] = h
            dimension[k, 1] = w
            dimension[k, 2] = l

            # targets for direction
            direction[k, 0] = math.sin(float(yaw))  # im
            direction[k, 1] = math.cos(float(yaw))  # re
            # im -->> -im
            if hflipped:
                direction[k, 0] = -direction[k, 0]

            # targets for depth
            z_coor[k] = z - minZ

            # Generate object masks
            obj_mask[k] = 1

        targets = {
            'hm_cen': hm_main_center,
            'cen_offset': cen_offset,
            'direction': direction,
            'z_coor': z_coor,
            'dim': dimension,
            'indices_center': indices_center,
            'obj_mask': obj_mask,
        }

        return targets
    def draw_img_with_label(self, index):
        sample_id = self.sample_id_list[index]
        lidar_path = os.path.join(self.lidar_dir, '{}.bin'.format(sample_id))
        lidarData = self.get_lidar(sample_id)
        calib = self.get_calib(sample_id)
        labels, has_labels = self.get_label(sample_id)
        print(lidar_path)
        if has_labels:
            labels[:, 1:] = transformation.camera_to_lidar_box(labels[:, 1:], calib.V2C, calib.R0, calib.P2)

        if self.lidar_aug:
            lidarData, labels[:, 1:] = self.lidar_aug(lidarData, labels[:, 1:])

        lidarData, labels = get_filtered_lidar(lidarData, cnf.boundary, labels)
        bev_map = makeBEVMap(lidarData, cnf.boundary)
        print(labels)
        return bev_map, labels, lidar_path
if __name__ == '__main__':
    from easydict import EasyDict as edict
    from data_process.transformation import OneOf, Random_Scaling, Random_Rotation, lidar_to_camera_box
    from utils.visualization_utils import merge_rgb_to_bev, show_rgb_image_with_boxes

    configs = edict()
    configs.distributed = False  # For testing
    configs.pin_memory = False
    configs.num_samples = None
    configs.input_size = (1216, 608)
    configs.hm_size = (304, 152)
    configs.max_objects = 50
    configs.num_classes = 3
    configs.output_width = 608
    # configs.dataset_dir = os.path.join('../../', 'dataset', 'kitti')
    # val_data_url must point at a folder containing velodyne/, calib/, label_2/
    # (required because mode='val' is used below); adjust to your local layout
    configs.val_data_url = os.path.join('../../', 'dataset', 'kitti')

    # lidar_aug = OneOf([
    #     Random_Rotation(limit_angle=np.pi / 4, p=1.),
    #     Random_Scaling(scaling_range=(0.95, 1.05), p=1.),
    # ], p=1.)
    lidar_aug = None

    dataset = KittiDataset(configs, mode='val', lidar_aug=lidar_aug, hflip_prob=0., num_samples=configs.num_samples)

    print('\n\nPress any key to see the next sample >>> Press Esc to quit...')
    for idx in range(len(dataset)):
        bev_map, labels, lidar_path = dataset.draw_img_with_label(idx)
        calib = Calibration(lidar_path.replace(".bin", ".txt").replace("velodyne", "calib"))
        bev_map = (bev_map.transpose(1, 2, 0) * 255).astype(np.uint8)
        # bev_map = cv2.resize(bev_map, (cnf.BEV_HEIGHT, cnf.BEV_WIDTH))
        print(bev_map.shape)

        for box_idx, (cls_id, x, y, z, h, w, l, yaw) in enumerate(labels):
            # Draw rotated box
            yaw = -yaw
            y1 = int((x - cnf.boundary['minX']) / cnf.DISCRETIZATION)
            x1 = int((y - cnf.boundary['minY']) / cnf.DISCRETIZATION)
            w1 = int(w / cnf.DISCRETIZATION)
            l1 = int(l / cnf.DISCRETIZATION)
            drawRotatedBox(bev_map, x1, y1, w1, l1, yaw, cnf.colors[int(cls_id)])

        # Rotate the bev_map
        bev_map = cv2.rotate(bev_map, cv2.ROTATE_180)
        # labels[:, 1:] = lidar_to_camera_box(labels[:, 1:], calib.V2C, calib.R0, calib.P2)

        cv2.imshow('bev_map', bev_map)
        if cv2.waitKey(0) & 0xff == 27:
            break
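
For reference, a minimal sketch of how this Dataset would typically be consumed during training. The configs fields match the class above; the dataset paths are placeholders for your own KITTI split, and the import path data_process.kitti_dataset is an assumption about where this file lives in the package:

    import torch
    from torch.utils.data import DataLoader
    from easydict import EasyDict as edict

    from data_process.kitti_dataset import KittiDataset  # assumed module path

    configs = edict()
    configs.dataset_dir = '../../dataset/kitti'    # placeholder path
    configs.val_data_url = '../../dataset/kitti'   # placeholder path
    configs.input_size = (1216, 608)
    configs.hm_size = (304, 152)
    configs.max_objects = 50
    configs.num_classes = 3

    train_set = KittiDataset(configs, mode='train', lidar_aug=None, hflip_prob=0.5)
    train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2)

    for bev_maps, targets in train_loader:
        # bev_maps: (B, 3, H, W) float tensor; the default collate function
        # batches the numpy target arrays into a dict of tensors, e.g.
        # targets['hm_cen'] has shape (B, num_classes, hm_l, hm_w)
        print(bev_maps.shape, targets['hm_cen'].shape)
        break

Because __getitem__ returns (bev_map, targets) with fixed-size arrays (max_objects rows plus the obj_mask), the default collate function batches everything without a custom collate_fn.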
