
inference.py 16 kB

  1. """
  2. # -*- coding: utf-8 -*-
  3. -----------------------------------------------------------------------------------
  4. # Author: Nguyen Mau Dung
  5. # DoC: 2020.08.17
  6. # email: nguyenmaudung93.kstn@gmail.com
  7. -----------------------------------------------------------------------------------
  8. # Description: Testing script
  9. """
import argparse
import sys
import os
import time
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

from easydict import EasyDict as edict
import cv2
import torch
import numpy as np
import torch.nn.functional as F

src_dir = os.path.dirname(os.path.realpath(__file__))
# while not src_dir.endswith("sfa"):
#     src_dir = os.path.dirname(src_dir)
if src_dir not in sys.path:
    sys.path.append(src_dir)

from data_process.kitti_dataloader import create_test_dataloader
from models.model_utils import create_model
import config.kitti_config as cnf
def parse_test_configs():
    parser = argparse.ArgumentParser(description='Testing config for the Implementation')
    parser.add_argument('--saved_fn', type=str, default='fpn_resnet_18', metavar='FN',
                        help='The name used for saving logs, models, ...')
    parser.add_argument('-a', '--arch', type=str, default='fpn_resnet_18', metavar='ARCH',
                        help='The name of the model architecture')
    parser.add_argument('--model_dir', type=str, default='/train_out_model/', metavar='PATH',
                        help='the path of the pretrained checkpoint (a file or a directory)')
    parser.add_argument('--K', type=int, default=50,
                        help='the number of top K detections to keep')
    parser.add_argument('--no_cuda', action='store_true',
                        help='If set, cuda is not used.')
    parser.add_argument('--gpu_idx', default=0, type=int,
                        help='GPU index to use.')
    parser.add_argument('--num_samples', type=int, default=None,
                        help='Take a subset of the dataset to run and debug')
    parser.add_argument('--num_workers', type=int, default=1,
                        help='Number of threads for loading data')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='mini-batch size (default: 1)')
    parser.add_argument('--peak_thresh', type=float, default=0.2,
                        help='confidence threshold for keeping a detection')
    parser.add_argument('--dataset_dir', type=str, default='/dataset_dir/',
                        help='the root directory of the test dataset')
    parser.add_argument('--results_dir', type=str, default='/results_dir/',
                        help='the directory where the test outputs are saved')
    parser.add_argument('--save_test_output', type=bool, default=True,
                        help='save the test output or not')
    parser.add_argument('--output_format', type=str, default='txt', metavar='FORMAT',
                        help='the type of the test output (image, video, or txt)')
    parser.add_argument('--output_video_fn', type=str, default='out_fpn_resnet_18', metavar='FN',
                        help='the video filename if the output format is video')
    parser.add_argument('--output-width', type=int, default=608,
                        help='the width of the displayed output; the height may vary')
    configs = edict(vars(parser.parse_args()))
    configs.pin_memory = True
    configs.distributed = False  # For testing on 1 GPU only

    configs.input_size = (1216, 608)
    configs.hm_size = (304, 152)
    configs.down_ratio = 4
    configs.max_objects = 50

    configs.imagenet_pretrained = False
    configs.head_conv = 64
    configs.num_classes = 3
    configs.num_center_offset = 2
    configs.num_z = 1
    configs.num_dim = 3
    configs.num_direction = 2  # sin, cos

    configs.heads = {
        'hm_cen': configs.num_classes,
        'cen_offset': configs.num_center_offset,
        'direction': configs.num_direction,
        'z_coor': configs.num_z,
        'dim': configs.num_dim
    }
    configs.num_input_features = 4

    ####################################################################
    ############## Dataset, Checkpoints, and results dir configs ######
    ####################################################################
    configs.root_dir = '../'
    # configs.dataset_dir = os.path.join(configs.root_dir, 'dataset', 'apollo')
    # results_dir_img is needed by the 'image' and 'video' output formats below
    configs.results_dir_img = os.path.join(configs.results_dir, configs.saved_fn, 'image')
    # configs.results_dir_txt = os.path.join(configs.results_dir, configs.saved_fn, 'txt')
    # make_folder(configs.results_dir_txt)
    make_folder(configs.results_dir_img)
    make_folder(configs.results_dir)

    return configs
def _sigmoid(x):
    return torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4)


def time_synchronized():
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()


def make_folder(folder_name):
    if not os.path.exists(folder_name):
        os.makedirs(folder_name)
def drawRotatedBox(img, x, y, w, l, yaw, color):
    bev_corners = get_corners(x, y, w, l, yaw)
    corners_int = bev_corners.reshape(-1, 1, 2).astype(int)
    cv2.polylines(img, [corners_int], True, color, 2)
    corners_int = bev_corners.reshape(-1, 2)
    # highlight the front edge (front-left to front-right corner)
    cv2.line(img, (int(corners_int[0, 0]), int(corners_int[0, 1])),
             (int(corners_int[3, 0]), int(corners_int[3, 1])), (255, 255, 0), 2)


# bev image coordinates format
def get_corners(x, y, w, l, yaw):
    bev_corners = np.zeros((4, 2), dtype=np.float32)
    cos_yaw = np.cos(yaw)
    sin_yaw = np.sin(yaw)
    # front left
    bev_corners[0, 0] = x - w / 2 * cos_yaw - l / 2 * sin_yaw
    bev_corners[0, 1] = y - w / 2 * sin_yaw + l / 2 * cos_yaw
    # rear left
    bev_corners[1, 0] = x - w / 2 * cos_yaw + l / 2 * sin_yaw
    bev_corners[1, 1] = y - w / 2 * sin_yaw - l / 2 * cos_yaw
    # rear right
    bev_corners[2, 0] = x + w / 2 * cos_yaw + l / 2 * sin_yaw
    bev_corners[2, 1] = y + w / 2 * sin_yaw - l / 2 * cos_yaw
    # front right
    bev_corners[3, 0] = x + w / 2 * cos_yaw - l / 2 * sin_yaw
    bev_corners[3, 1] = y + w / 2 * sin_yaw + l / 2 * cos_yaw
    return bev_corners
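

# --- Hypothetical sanity check, not part of the original script ---
# With yaw = 0 the rotation terms vanish and get_corners() reduces to an
# axis-aligned rectangle centred on (x, y), e.g.:
#     corners = get_corners(x=10.0, y=20.0, w=2.0, l=4.0, yaw=0.0)
#     corners[0]  # front-left -> (x - w/2, y + l/2) == (9.0, 22.0)
#     corners[2]  # rear-right -> (x + w/2, y - l/2) == (11.0, 18.0)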
def _nms(heat, kernel=3):
    # max-pool-based NMS: keep only local peaks of the heatmap
    pad = (kernel - 1) // 2
    hmax = F.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad)
    keep = (hmax == heat).float()
    return heat * keep


def _gather_feat(feat, ind, mask=None):
    dim = feat.size(2)
    ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
    feat = feat.gather(1, ind)
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(feat)
        feat = feat[mask]
        feat = feat.view(-1, dim)
    return feat


def _transpose_and_gather_feat(feat, ind):
    feat = feat.permute(0, 2, 3, 1).contiguous()
    feat = feat.view(feat.size(0), -1, feat.size(3))
    feat = _gather_feat(feat, ind)
    return feat
def _topk(scores, K=40):
    batch, cat, height, width = scores.size()
    topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)
    topk_inds = topk_inds % (height * width)
    topk_ys = (torch.floor_divide(topk_inds, width)).float()
    topk_xs = (topk_inds % width).int().float()
    # select the overall top K across all classes
    topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
    topk_clses = (torch.floor_divide(topk_ind, K)).int()
    topk_inds = _gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)
    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
def decode(hm_cen, cen_offset, direction, z_coor, dim, K=40):
    batch_size, num_classes, height, width = hm_cen.size()
    hm_cen = _nms(hm_cen)
    scores, inds, clses, ys, xs = _topk(hm_cen, K=K)
    if cen_offset is not None:
        cen_offset = _transpose_and_gather_feat(cen_offset, inds)
        cen_offset = cen_offset.view(batch_size, K, 2)
        xs = xs.view(batch_size, K, 1) + cen_offset[:, :, 0:1]
        ys = ys.view(batch_size, K, 1) + cen_offset[:, :, 1:2]
    else:
        xs = xs.view(batch_size, K, 1) + 0.5
        ys = ys.view(batch_size, K, 1) + 0.5

    direction = _transpose_and_gather_feat(direction, inds)
    direction = direction.view(batch_size, K, 2)
    z_coor = _transpose_and_gather_feat(z_coor, inds)
    z_coor = z_coor.view(batch_size, K, 1)
    dim = _transpose_and_gather_feat(dim, inds)
    dim = dim.view(batch_size, K, 3)
    clses = clses.view(batch_size, K, 1).float()
    scores = scores.view(batch_size, K, 1)

    # (scores x 1, xs x 1, ys x 1, z_coor x 1, dim x 3, direction x 2, clses x 1)
    # (scores-0:1, xs-1:2, ys-2:3, z_coor-3:4, dim-4:7, direction-7:9, clses-9:10)
    # detections: [batch_size, K, 10]
    detections = torch.cat([scores, xs, ys, z_coor, dim, direction, clses], dim=2)

    return detections
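

# --- Illustrative slicing of decode()'s output (not part of the original script) ---
# Following the layout comment above, a [batch_size, K, 10] detections tensor
# can be unpacked as:
#     scores = detections[:, :, 0:1]
#     xs, ys = detections[:, :, 1:2], detections[:, :, 2:3]
#     z_coor = detections[:, :, 3:4]
#     dim    = detections[:, :, 4:7]   # (h, w, l)
#     yaw    = torch.atan2(detections[:, :, 7:8], detections[:, :, 8:9])  # cf. get_yaw()
#     clses  = detections[:, :, 9:10]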
def get_yaw(direction):
    # direction holds (sin(yaw), cos(yaw)); arctan2 recovers the angle
    return np.arctan2(direction[:, 0:1], direction[:, 1:2])
def post_processing(detections, num_classes=3, down_ratio=4, peak_thresh=0.2):
    """
    :param detections: [batch_size, K, 10]
    # (scores x 1, xs x 1, ys x 1, z_coor x 1, dim x 3, direction x 2, clses x 1)
    # (scores-0:1, xs-1:2, ys-2:3, z_coor-3:4, dim-4:7, direction-7:9, clses-9:10)
    :return: a list (one dict per sample) mapping class id to an array of
             (score, x, y, z, h, w, l, yaw) rows
    """
    # TODO: Need to consider rescaling to the original scale: x, y
    ret = []
    for i in range(detections.shape[0]):
        top_preds = {}
        classes = detections[i, :, -1]
        for j in range(num_classes):
            inds = (classes == j)
            # (score, x, y, z, h, w, l, yaw)
            top_preds[j] = np.concatenate([
                detections[i, inds, 0:1],
                detections[i, inds, 1:2] * down_ratio,
                detections[i, inds, 2:3] * down_ratio,
                detections[i, inds, 3:4],
                detections[i, inds, 4:5],
                detections[i, inds, 5:6] / cnf.bound_size_y * cnf.BEV_WIDTH,
                detections[i, inds, 6:7] / cnf.bound_size_x * cnf.BEV_HEIGHT,
                get_yaw(detections[i, inds, 7:9]).astype(np.float32)], axis=1)
            # Filter by peak_thresh
            if len(top_preds[j]) > 0:
                keep_inds = (top_preds[j][:, 0] > peak_thresh)
                top_preds[j] = top_preds[j][keep_inds]
        ret.append(top_preds)
    return ret
def draw_predictions(img, detections, num_classes=3):
    for j in range(num_classes):
        if len(detections[j]) > 0:
            for det in detections[j]:
                # (scores-0:1, x-1:2, y-2:3, z-3:4, dim-4:7, yaw-7:8)
                _score, _x, _y, _z, _h, _w, _l, _yaw = det
                drawRotatedBox(img, _x, _y, _w, _l, _yaw, cnf.colors[int(j)])
    return img
def convert_det_to_real_values(detections, num_classes=3):
    kitti_dets = []
    for cls_id in range(num_classes):
        if len(detections[cls_id]) > 0:
            for det in detections[cls_id]:
                # (scores-0:1, x-1:2, y-2:3, z-3:4, dim-4:7, yaw-7:8)
                _score, _x, _y, _z, _h, _w, _l, _yaw = det
                _yaw = round(-_yaw, 2)
                # map BEV pixel coordinates back to metric (sensor) coordinates
                x = round(_y / cnf.BEV_HEIGHT * cnf.bound_size_x + cnf.boundary['minX'], 2)
                y = round(_x / cnf.BEV_WIDTH * cnf.bound_size_y + cnf.boundary['minY'], 2)
                z = round(_z + cnf.boundary['minZ'], 2)
                w = round(_w / cnf.BEV_WIDTH * cnf.bound_size_y, 2)
                l = round(_l / cnf.BEV_HEIGHT * cnf.bound_size_x, 2)
                h = round(_h, 2)
                kitti_dets.append([cls_id, h, w, l, x, y, z, _yaw])
    return np.array(kitti_dets)
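

# --- Worked example with hypothetical config values (not from this repo) ---
# Suppose cnf.boundary['minX'] = 0, cnf.bound_size_x = 50 (metres) and
# cnf.BEV_HEIGHT = 608 pixels. A detection at BEV row _y = 304 then maps to
#     x = 304 / 608 * 50 + 0 = 25.0 m
# in front of the sensor; the other axes follow the same pattern.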
if __name__ == '__main__':
    print("=".ljust(66, "="))
    configs = parse_test_configs()
    model = create_model(configs)
    print('\n\n' + '-*=' * 30 + '\n\n')

    # assert os.path.isfile(configs.model_dir), "No file at {}".format(configs.model_dir)
    if os.path.isfile(configs.model_dir):
        model_path = configs.model_dir
    else:
        # take the last checkpoint in the directory (sorted by filename,
        # since os.listdir() returns entries in arbitrary order)
        model_path = os.path.join(configs.model_dir, sorted(os.listdir(configs.model_dir))[-1])
    print('Loaded weights from {}\n'.format(model_path))

    configs.device = torch.device('cpu' if configs.no_cuda else 'cuda:{}'.format(configs.gpu_idx))
    model.load_state_dict(torch.load(model_path, map_location=configs.device))
    model = model.to(device=configs.device)

    out_cap = None
    model.eval()

    test_dataloader = create_test_dataloader(configs)
    with torch.no_grad():
        for batch_idx, batch_data in enumerate(test_dataloader):
            bev_maps, metadatas = batch_data
            input_bev_maps = bev_maps.to(configs.device, non_blocking=True).float()
            t1 = time_synchronized()
            outputs = model(input_bev_maps)
            outputs['hm_cen'] = _sigmoid(outputs['hm_cen'])
            outputs['cen_offset'] = _sigmoid(outputs['cen_offset'])
            # detections size (batch_size, K, 10)
            detections = decode(outputs['hm_cen'], outputs['cen_offset'], outputs['direction'],
                                outputs['z_coor'], outputs['dim'], K=configs.K)
            detections = detections.cpu().numpy().astype(np.float32)
            detections = post_processing(detections, configs.num_classes, configs.down_ratio,
                                         configs.peak_thresh)
            t2 = time_synchronized()

            detections = detections[0]  # only the first sample of the batch

            # Draw predictions on the BEV map
            bev_map = (bev_maps.squeeze().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
            bev_map = cv2.resize(bev_map, (cnf.BEV_WIDTH, cnf.BEV_HEIGHT))
            bev_map = draw_predictions(bev_map, detections.copy(), configs.num_classes)

            # Rotate the bev_map
            bev_map = cv2.rotate(bev_map, cv2.ROTATE_180)

            kitti_dets = convert_det_to_real_values(detections)
            print('\tDone testing the {}th sample, time: {:.1f}ms, speed {:.2f}FPS'.format(
                batch_idx, (t2 - t1) * 1000, 1 / (t2 - t1)))

            if configs.save_test_output:
                img_fn = os.path.basename(metadatas['bev_path'][0])[:-4]
                if configs.output_format == 'image':
                    cv2.imwrite(os.path.join(configs.results_dir_img, '{}.jpg'.format(img_fn)), bev_map)
                elif configs.output_format == 'video':
                    if out_cap is None:
                        out_cap_h, out_cap_w = bev_map.shape[:2]
                        fourcc = cv2.VideoWriter_fourcc(*'MJPG')
                        out_cap = cv2.VideoWriter(
                            os.path.join(configs.results_dir_img, '{}.avi'.format(configs.output_video_fn)),
                            fourcc, 30, (out_cap_w, out_cap_h))
                    out_cap.write(bev_map)
                else:
                    pass

                # write detections in the KITTI label format; the seven zero-filled
                # fields are truncation, occlusion, alpha, and the 2D bbox
                txt_path = os.path.join(configs.results_dir, '{}.txt'.format(img_fn))
                with open(txt_path, 'w') as txt_file:
                    for det in kitti_dets:
                        write_line = cnf.CLASS_ID_TO_NAME[int(det[0])] + ' 0 0 0 0 0 0 0 ' + \
                            str(det[1]) + ' ' + str(det[2]) + ' ' + str(det[3]) + ' ' + str(det[4]) + \
                            ' ' + str(det[5]) + ' ' + str(det[6]) + ' ' + str(det[7]) + '\n'
                        txt_file.write(write_line)

    if out_cap:
        out_cap.release()
    cv2.destroyAllWindows()
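
Below is a minimal, self-contained sketch (not part of inference.py) of how the decoding pipeline above consumes the network's head outputs. The random tensors stand in for real model predictions; it assumes the functions defined in this file are in scope, and it reads configs.hm_size = (304, 152) as (W, H), i.e. a (152, 304) feature map, which is an assumption.

import torch
import numpy as np

# random stand-ins for the model's output heads, shaped [B, C, H, W]
batch, K, H, W = 1, 50, 152, 304
hm_cen = torch.rand(batch, 3, H, W)       # per-class center heatmap
cen_offset = torch.rand(batch, 2, H, W)   # sub-pixel center offsets
direction = torch.rand(batch, 2, H, W)    # (sin, cos) encoding of yaw
z_coor = torch.rand(batch, 1, H, W)       # z-coordinate head
dim = torch.rand(batch, 3, H, W)          # (h, w, l) box dimensions

dets = decode(_sigmoid(hm_cen), _sigmoid(cen_offset), direction, z_coor, dim, K=K)
dets = dets.cpu().numpy().astype(np.float32)
dets = post_processing(dets, num_classes=3, down_ratio=4, peak_thresh=0.2)
# dets[0] maps class id -> array of (score, x, y, z, h, w, l, yaw) rows
print({cls_id: preds.shape for cls_id, preds in dets[0].items()})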
