""" # -*- coding: utf-8 -*- ----------------------------------------------------------------------------------- # Author: Nguyen Mau Dung # DoC: 2020.08.17 # email: nguyenmaudung93.kstn@gmail.com ----------------------------------------------------------------------------------- # Description: This script for creating the dataloader for training/validation/test phase """ import os import sys import torch from torch.utils.data import DataLoader import numpy as np src_dir = os.path.dirname(os.path.realpath(__file__)) # while not src_dir.endswith("sfa"): # src_dir = os.path.dirname(src_dir) if src_dir not in sys.path: sys.path.append(src_dir) from data_process.kitti_dataset import KittiDataset from data_process.transformation import OneOf, Random_Rotation, Random_Scaling def create_train_dataloader(configs): """Create dataloader for training""" train_lidar_aug = OneOf([ Random_Rotation(limit_angle=np.pi / 4, p=1.0), Random_Scaling(scaling_range=(0.95, 1.05), p=1.0), ], p=0.66) train_dataset = KittiDataset(configs, mode='train', lidar_aug=train_lidar_aug, hflip_prob=configs.hflip_prob, num_samples=configs.num_samples) train_sampler = None if configs.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) train_dataloader = DataLoader(train_dataset, batch_size=configs.batch_size, shuffle=(train_sampler is None), pin_memory=configs.pin_memory, num_workers=configs.num_workers, sampler=train_sampler) return train_dataloader, train_sampler def create_val_dataloader(configs): """Create dataloader for validation""" val_sampler = None val_dataset = KittiDataset(configs, mode='val', lidar_aug=None, hflip_prob=0., num_samples=configs.num_samples) if configs.distributed: val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False) val_dataloader = DataLoader(val_dataset, batch_size=configs.batch_size, shuffle=False, pin_memory=configs.pin_memory, num_workers=configs.num_workers, sampler=val_sampler) return val_dataloader def create_test_dataloader(configs): """Create dataloader for testing phase""" test_dataset = KittiDataset(configs, mode='test', lidar_aug=None, hflip_prob=0., num_samples=configs.num_samples) test_sampler = None if configs.distributed: test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset) test_dataloader = DataLoader(test_dataset, batch_size=configs.batch_size, shuffle=False, pin_memory=configs.pin_memory, num_workers=configs.num_workers, sampler=test_sampler) return test_dataloader