You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

evaluation_utils.py 6.7 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. """
  2. # -*- coding: utf-8 -*-
  3. -----------------------------------------------------------------------------------
  4. # Author: Nguyen Mau Dung
  5. # DoC: 2020.08.17
  6. # email: nguyenmaudung93.kstn@gmail.com
  7. -----------------------------------------------------------------------------------
  8. # Description: The utils for evaluation
  9. # Refer from: https://github.com/xingyizhou/CenterNet
  10. """
  11. from __future__ import division
  12. import os
  13. import sys
  14. import torch
  15. import numpy as np
  16. import torch.nn.functional as F
  17. import cv2
  18. src_dir = os.path.dirname(os.path.realpath(__file__))
  19. # while not src_dir.endswith("sfa"):
  20. # src_dir = os.path.dirname(src_dir)
  21. if src_dir not in sys.path:
  22. sys.path.append(src_dir)
  23. import config.kitti_config as cnf
  24. from data_process.kitti_bev_utils import drawRotatedBox
  25. def _nms(heat, kernel=3):
  26. pad = (kernel - 1) // 2
  27. hmax = F.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad)
  28. keep = (hmax == heat).float()
  29. return heat * keep
  30. def _gather_feat(feat, ind, mask=None):
  31. dim = feat.size(2)
  32. ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
  33. feat = feat.gather(1, ind)
  34. if mask is not None:
  35. mask = mask.unsqueeze(2).expand_as(feat)
  36. feat = feat[mask]
  37. feat = feat.view(-1, dim)
  38. return feat
  39. def _transpose_and_gather_feat(feat, ind):
  40. feat = feat.permute(0, 2, 3, 1).contiguous()
  41. feat = feat.view(feat.size(0), -1, feat.size(3))
  42. feat = _gather_feat(feat, ind)
  43. return feat
  44. def _topk(scores, K=40):
  45. batch, cat, height, width = scores.size()
  46. topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)
  47. topk_inds = topk_inds % (height * width)
  48. topk_ys = (torch.floor_divide(topk_inds, width)).float()
  49. topk_xs = (topk_inds % width).int().float()
  50. topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
  51. topk_clses = (torch.floor_divide(topk_ind, K)).int()
  52. topk_inds = _gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
  53. topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
  54. topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)
  55. return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
  56. def _topk_channel(scores, K=40):
  57. batch, cat, height, width = scores.size()
  58. topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)
  59. topk_inds = topk_inds % (height * width)
  60. topk_ys = (topk_inds / width).int().float()
  61. topk_xs = (topk_inds % width).int().float()
  62. return topk_scores, topk_inds, topk_ys, topk_xs
  63. def decode(hm_cen, cen_offset, direction, z_coor, dim, K=40):
  64. batch_size, num_classes, height, width = hm_cen.size()
  65. hm_cen = _nms(hm_cen)
  66. scores, inds, clses, ys, xs = _topk(hm_cen, K=K)
  67. if cen_offset is not None:
  68. cen_offset = _transpose_and_gather_feat(cen_offset, inds)
  69. cen_offset = cen_offset.view(batch_size, K, 2)
  70. xs = xs.view(batch_size, K, 1) + cen_offset[:, :, 0:1]
  71. ys = ys.view(batch_size, K, 1) + cen_offset[:, :, 1:2]
  72. else:
  73. xs = xs.view(batch_size, K, 1) + 0.5
  74. ys = ys.view(batch_size, K, 1) + 0.5
  75. direction = _transpose_and_gather_feat(direction, inds)
  76. direction = direction.view(batch_size, K, 2)
  77. z_coor = _transpose_and_gather_feat(z_coor, inds)
  78. z_coor = z_coor.view(batch_size, K, 1)
  79. dim = _transpose_and_gather_feat(dim, inds)
  80. dim = dim.view(batch_size, K, 3)
  81. clses = clses.view(batch_size, K, 1).float()
  82. scores = scores.view(batch_size, K, 1)
  83. # (scores x 1, ys x 1, xs x 1, z_coor x 1, dim x 3, direction x 2, clses x 1)
  84. # (scores-0:1, ys-1:2, xs-2:3, z_coor-3:4, dim-4:7, direction-7:9, clses-9:10)
  85. # detections: [batch_size, K, 10]
  86. detections = torch.cat([scores, xs, ys, z_coor, dim, direction, clses], dim=2)
  87. return detections
  88. def get_yaw(direction):
  89. return np.arctan2(direction[:, 0:1], direction[:, 1:2])
  90. def post_processing(detections, num_classes=3, down_ratio=4, peak_thresh=0.2):
  91. """
  92. :param detections: [batch_size, K, 10]
  93. # (scores x 1, xs x 1, ys x 1, z_coor x 1, dim x 3, direction x 2, clses x 1)
  94. # (scores-0:1, xs-1:2, ys-2:3, z_coor-3:4, dim-4:7, direction-7:9, clses-9:10)
  95. :return:
  96. """
  97. # TODO: Need to consider rescale to the original scale: x, y
  98. ret = []
  99. for i in range(detections.shape[0]):
  100. top_preds = {}
  101. classes = detections[i, :, -1]
  102. for j in range(num_classes):
  103. inds = (classes == j)
  104. # x, y, z, h, w, l, yaw
  105. top_preds[j] = np.concatenate([
  106. detections[i, inds, 0:1],
  107. detections[i, inds, 1:2] * down_ratio,
  108. detections[i, inds, 2:3] * down_ratio,
  109. detections[i, inds, 3:4],
  110. detections[i, inds, 4:5],
  111. detections[i, inds, 5:6] / cnf.bound_size_y * cnf.BEV_WIDTH,
  112. detections[i, inds, 6:7] / cnf.bound_size_x * cnf.BEV_HEIGHT,
  113. get_yaw(detections[i, inds, 7:9]).astype(np.float32)], axis=1)
  114. # Filter by peak_thresh
  115. if len(top_preds[j]) > 0:
  116. keep_inds = (top_preds[j][:, 0] > peak_thresh)
  117. top_preds[j] = top_preds[j][keep_inds]
  118. ret.append(top_preds)
  119. return ret
  120. def draw_predictions(img, detections, num_classes=3):
  121. for j in range(num_classes):
  122. if len(detections[j]) > 0:
  123. for det in detections[j]:
  124. # (scores-0:1, x-1:2, y-2:3, z-3:4, dim-4:7, yaw-7:8)
  125. _score, _x, _y, _z, _h, _w, _l, _yaw = det
  126. drawRotatedBox(img, _x, _y, _w, _l, _yaw, cnf.colors[int(j)])
  127. return img
  128. def convert_det_to_real_values(detections, num_classes=3):
  129. kitti_dets = []
  130. for cls_id in range(num_classes):
  131. if len(detections[cls_id]) > 0:
  132. for det in detections[cls_id]:
  133. # (scores-0:1, x-1:2, y-2:3, z-3:4, dim-4:7, yaw-7:8)
  134. _score, _x, _y, _z, _h, _w, _l, _yaw = det
  135. _yaw = round(-_yaw, 2)
  136. x = round(_y / cnf.BEV_HEIGHT * cnf.bound_size_x + cnf.boundary['minX'], 2)
  137. y = round(_x / cnf.BEV_WIDTH * cnf.bound_size_y + cnf.boundary['minY'], 2)
  138. z = round(_z + cnf.boundary['minZ'], 2)
  139. w = round(_w / cnf.BEV_WIDTH * cnf.bound_size_y, 2)
  140. l = round(_l / cnf.BEV_HEIGHT * cnf.bound_size_x, 2)
  141. h = round(_h/1, 2)
  142. kitti_dets.append([cls_id, h, w, l, x, y, z, _yaw])
  143. return np.array(kitti_dets)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能