""" # -*- coding: utf-8 -*- ----------------------------------------------------------------------------------- # Author: Nguyen Mau Dung # DoC: 2020.08.17 # email: nguyenmaudung93.kstn@gmail.com ----------------------------------------------------------------------------------- # Description: This script for the KITTI dataset """ import sys import os from builtins import int from glob import glob import numpy as np from torch.utils.data import Dataset import cv2 import torch src_dir = os.path.dirname(os.path.realpath(__file__)) # while not src_dir.endswith("sfa"): # src_dir = os.path.dirname(src_dir) if src_dir not in sys.path: sys.path.append(src_dir) from data_process.kitti_data_utils import get_filtered_lidar from data_process.kitti_bev_utils import makeBEVMap import config.kitti_config as cnf class Demo_KittiDataset(Dataset): def __init__(self, configs): self.dataset_dir = os.path.join(configs.dataset_dir, configs.foldername, configs.foldername[:10], configs.foldername) self.input_size = configs.input_size self.hm_size = configs.hm_size self.num_classes = configs.num_classes self.max_objects = configs.max_objects self.image_dir = os.path.join(self.dataset_dir, "image_02", "data") self.lidar_dir = os.path.join(self.dataset_dir, "velodyne_points", "data") self.label_dir = os.path.join(self.dataset_dir, "label_2", "data") self.sample_id_list = sorted(glob(os.path.join(self.lidar_dir, '*.bin'))) self.sample_id_list = [float(os.path.basename(fn)[:-4]) for fn in self.sample_id_list] self.num_samples = len(self.sample_id_list) def __len__(self): return len(self.sample_id_list) def __getitem__(self, index): pass def load_bevmap_front(self, index): """Load only image for the testing phase""" sample_id = int(self.sample_id_list[index]) img_path, img_rgb = self.get_image(sample_id) lidarData = self.get_lidar(sample_id) front_lidar = get_filtered_lidar(lidarData, cnf.boundary) front_bevmap = makeBEVMap(front_lidar, cnf.boundary) front_bevmap = torch.from_numpy(front_bevmap) metadatas = { 'img_path': img_path, } return metadatas, front_bevmap, img_rgb def load_bevmap_front_vs_back(self, index): """Load only image for the testing phase""" sample_id = int(self.sample_id_list[index]) img_path, img_rgb = self.get_image(sample_id) lidarData = self.get_lidar(sample_id) front_lidar = get_filtered_lidar(lidarData, cnf.boundary) front_bevmap = makeBEVMap(front_lidar, cnf.boundary) front_bevmap = torch.from_numpy(front_bevmap) back_lidar = get_filtered_lidar(lidarData, cnf.boundary_back) back_bevmap = makeBEVMap(back_lidar, cnf.boundary_back) back_bevmap = torch.from_numpy(back_bevmap) metadatas = { 'img_path': img_path, } return metadatas, front_bevmap, back_bevmap, img_rgb def get_image(self, idx): img_path = os.path.join(self.image_dir, '{:010d}.png'.format(idx)) img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB) return img_path, img def get_lidar(self, idx): lidar_file = os.path.join(self.lidar_dir, '{:010d}.bin'.format(idx)) # assert os.path.isfile(lidar_file) return np.fromfile(lidar_file, dtype=np.float32).reshape(-1, 4)