
demo_utils.py

  1. """
  2. # -*- coding: utf-8 -*-
  3. -----------------------------------------------------------------------------------
  4. # Author: Nguyen Mau Dung
  5. # DoC: 2020.08.17
  6. # email: nguyenmaudung93.kstn@gmail.com
  7. -----------------------------------------------------------------------------------
  8. # Description: Demonstration utils script
  9. """
  10. import argparse
  11. import sys
  12. import os
  13. import warnings
  14. import zipfile
  15. warnings.filterwarnings("ignore", category=UserWarning)
  16. from easydict import EasyDict as edict
  17. import numpy as np
  18. import wget
  19. import torch
  20. import cv2
  21. src_dir = os.path.dirname(os.path.realpath(__file__))
  22. # while not src_dir.endswith("sfa"):
  23. # src_dir = os.path.dirname(src_dir)
  24. if src_dir not in sys.path:
  25. sys.path.append(src_dir)
  26. from utils.misc import make_folder, time_synchronized
  27. from utils.evaluation_utils import decode, post_processing
  28. from utils.torch_utils import _sigmoid
  29. def parse_demo_configs():
  30. parser = argparse.ArgumentParser(description='Demonstration config for the implementation')
  31. parser.add_argument('--saved_fn', type=str, default='fpn_resnet_18', metavar='FN',
  32. help='The name using for saving logs, models,...')
  33. parser.add_argument('-a', '--arch', type=str, default='fpn_resnet_18', metavar='ARCH',
  34. help='The name of the model architecture')
  35. parser.add_argument('--pretrained_path', type=str,
  36. default='../checkpoints/fpn_resnet_18/fpn_resnet_18_epoch_300.pth', metavar='PATH',
  37. help='the path of the pretrained checkpoint')
  38. parser.add_argument('--foldername', type=str, default='2011_09_26_drive_0014_sync', metavar='FN',
  39. help='Folder name for demostration dataset')
  40. parser.add_argument('--K', type=int, default=50,
  41. help='the number of top K')
  42. parser.add_argument('--no_cuda', action='store_true',
  43. help='If true, cuda is not used.')
  44. parser.add_argument('--gpu_idx', default=0, type=int,
  45. help='GPU index to use.')
  46. parser.add_argument('--peak_thresh', type=float, default=0.2)
  47. parser.add_argument('--output_format', type=str, default='image', metavar='PATH',
  48. help='the type of the test output (support image or video)')
  49. parser.add_argument('--output-width', type=int, default=608,
  50. help='the width of showing output, the height maybe vary')
  51. configs = edict(vars(parser.parse_args()))
  52. configs.pin_memory = True
  53. configs.distributed = False # For testing on 1 GPU only
  54. configs.input_size = (608, 608)
  55. configs.hm_size = (152, 152)
  56. configs.down_ratio = 4
  57. configs.max_objects = 50
  58. configs.imagenet_pretrained = False
  59. configs.head_conv = 64
  60. configs.num_classes = 3
  61. configs.num_center_offset = 2
  62. configs.num_z = 1
  63. configs.num_dim = 3
  64. configs.num_direction = 2 # sin, cos
  65. configs.heads = {
  66. 'hm_cen': configs.num_classes,
  67. 'cen_offset': configs.num_center_offset,
  68. 'direction': configs.num_direction,
  69. 'z_coor': configs.num_z,
  70. 'dim': configs.num_dim
  71. }
  72. ####################################################################
  73. ##############Dataset, Checkpoints, and results dir configs#########
  74. ####################################################################
  75. configs.root_dir = '../'
  76. configs.dataset_dir = os.path.join(configs.root_dir, 'dataset', 'kitti', 'demo')
  77. configs.calib_path = os.path.join(configs.root_dir, 'dataset', 'kitti', 'demo', 'calib.txt')
  78. configs.results_dir = os.path.join(configs.root_dir, 'results', configs.saved_fn)
  79. make_folder(configs.results_dir)
  80. return configs
  81. def download_and_unzip(demo_dataset_dir, download_url):
  82. filename = download_url.split('/')[-1]
  83. filepath = os.path.join(demo_dataset_dir, filename)
  84. if os.path.isfile(filepath):
  85. print('The dataset have been downloaded')
  86. return
  87. print('\nDownloading data for demonstration...')
  88. wget.download(download_url, filepath)
  89. print('\nUnzipping the downloaded data...')
  90. with zipfile.ZipFile(filepath, "r") as zip_ref:
  91. zip_ref.extractall(os.path.join(demo_dataset_dir, filename[:-4]))
  92. def do_detect(configs, model, bevmap, is_front):
  93. if not is_front:
  94. bevmap = torch.flip(bevmap, [1, 2])
  95. input_bev_maps = bevmap.unsqueeze(0).to(configs.device, non_blocking=True).float()
  96. t1 = time_synchronized()
  97. outputs = model(input_bev_maps)
  98. outputs['hm_cen'] = _sigmoid(outputs['hm_cen'])
  99. outputs['cen_offset'] = _sigmoid(outputs['cen_offset'])
  100. # detections size (batch_size, K, 10)
  101. detections = decode(outputs['hm_cen'], outputs['cen_offset'], outputs['direction'], outputs['z_coor'],
  102. outputs['dim'], K=configs.K)
  103. detections = detections.cpu().numpy().astype(np.float32)
  104. detections = post_processing(detections, configs.num_classes, configs.down_ratio, configs.peak_thresh)
  105. t2 = time_synchronized()
  106. # Inference speed
  107. fps = 1 / (t2 - t1)
  108. return detections[0], bevmap, fps
  109. def write_credit(img, org_author=(500, 400), text_author='github.com/maudzung', org_fps=(50, 1000), fps=None):
  110. font = cv2.FONT_HERSHEY_SIMPLEX
  111. fontScale = 1
  112. color = (255, 255, 255)
  113. thickness = 2
  114. cv2.putText(img, text_author, org_author, font, fontScale, color, thickness, cv2.LINE_AA)
  115. cv2.putText(img, 'Speed: {:.1f} FPS'.format(fps), org_fps, font, fontScale, color, thickness, cv2.LINE_AA)
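
The file above only defines helpers; the demo script that ties them together lives elsewhere in the repository. Below is a minimal sketch of how these helpers might be called, assuming the caller sets configs.device (which do_detect reads but parse_demo_configs never sets) and builds the network through the project's own model factory; that factory and the BEV-map preparation are project-specific and appear only in comments here.

# sketch: wiring the demo helpers together (not part of demo_utils.py)
import torch

from demo_utils import parse_demo_configs, do_detect, write_credit

if __name__ == '__main__':
    configs = parse_demo_configs()
    # do_detect() reads configs.device, so the caller must choose it explicitly
    configs.device = torch.device('cpu' if configs.no_cuda else 'cuda:{}'.format(configs.gpu_idx))

    # Build and load the detector with the project's own model factory, e.g. (assumption):
    #   model = create_model(configs)
    #   model.load_state_dict(torch.load(configs.pretrained_path, map_location='cpu'))
    #   model = model.to(configs.device).eval()

    # For each frame, prepare a (C, H, W) bird's-eye-view tensor (project-specific) and run:
    #   with torch.no_grad():
    #       detections, bevmap, fps = do_detect(configs, model, bevmap, is_front=True)
    #       # draw the detections onto an output image, then stamp the author credit and speed:
    #       write_credit(out_img, org_author=(80, 30), org_fps=(80, 60), fps=fps)

Note that do_detect brackets its timing around the forward pass, decode, and post_processing together, so the fps it returns reflects the whole detection path rather than the network forward pass alone.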
