You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kitti_dataloader.py 2.9 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. """
  2. # -*- coding: utf-8 -*-
  3. -----------------------------------------------------------------------------------
  4. # Author: Nguyen Mau Dung
  5. # DoC: 2020.08.17
  6. # email: nguyenmaudung93.kstn@gmail.com
  7. -----------------------------------------------------------------------------------
  8. # Description: This script creates the dataloaders for the training, validation, and test phases
  9. """
  10. import os
  11. import sys
  12. import torch
  13. from torch.utils.data import DataLoader
  14. import numpy as np
  15. src_dir = os.path.dirname(os.path.realpath(__file__))
  16. # while not src_dir.endswith("sfa"):
  17. # src_dir = os.path.dirname(src_dir)
  18. if src_dir not in sys.path:
  19. sys.path.append(src_dir)
  20. from data_process.kitti_dataset import KittiDataset
  21. from data_process.transformation import OneOf, Random_Rotation, Random_Scaling
  22. def create_train_dataloader(configs):
  23. """Create dataloader for training"""
  24. train_lidar_aug = OneOf([
  25. Random_Rotation(limit_angle=np.pi / 4, p=1.0),
  26. Random_Scaling(scaling_range=(0.95, 1.05), p=1.0),
  27. ], p=0.66)
  28. train_dataset = KittiDataset(configs, mode='train', lidar_aug=train_lidar_aug, hflip_prob=configs.hflip_prob,
  29. num_samples=configs.num_samples)
  30. train_sampler = None
  31. if configs.distributed:
  32. train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
  33. train_dataloader = DataLoader(train_dataset, batch_size=configs.batch_size, shuffle=(train_sampler is None),
  34. pin_memory=configs.pin_memory, num_workers=configs.num_workers, sampler=train_sampler)
  35. return train_dataloader, train_sampler
  36. def create_val_dataloader(configs):
  37. """Create dataloader for validation"""
  38. val_sampler = None
  39. val_dataset = KittiDataset(configs, mode='val', lidar_aug=None, hflip_prob=0., num_samples=configs.num_samples)
  40. if configs.distributed:
  41. val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
  42. val_dataloader = DataLoader(val_dataset, batch_size=configs.batch_size, shuffle=False,
  43. pin_memory=configs.pin_memory, num_workers=configs.num_workers, sampler=val_sampler)
  44. return val_dataloader
  45. def create_test_dataloader(configs):
  46. """Create dataloader for testing phase"""
  47. test_dataset = KittiDataset(configs, mode='test', lidar_aug=None, hflip_prob=0., num_samples=configs.num_samples)
  48. test_sampler = None
  49. if configs.distributed:
  50. test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
  51. test_dataloader = DataLoader(test_dataset, batch_size=configs.batch_size, shuffle=False,
  52. pin_memory=configs.pin_memory, num_workers=configs.num_workers, sampler=test_sampler)
  53. return test_dataloader

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已为产学研等各领域近千家单位及个人提供AI应用赋能