You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

camvid.py 2.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # Modified from https://github.com/davidtvs/PyTorch-ENet/blob/master/data/camvid.py
  2. import os
  3. import torch.utils.data as data
  4. from glob import glob
  5. from PIL import Image
  6. import numpy as np
  7. from torchvision.datasets import VisionDataset
  8. class CamVid(VisionDataset):
  9. """CamVid dataset loader where the dataset is arranged as in https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid.
  10. Args:
  11. root (string):
  12. split (string): The type of dataset: 'train', 'val', 'trainval', or 'test'
  13. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. Default: None.
  14. target_transform (callable, optional): A function/transform that takes in the target and transform it. Default: None.
  15. transforms (callable, optional): A function/transform that takes in both the image and target and transform them. Default: None.
  16. """
  17. cmap = np.array([
  18. (128, 128, 128),
  19. (128, 0, 0),
  20. (192, 192, 128),
  21. (128, 64, 128),
  22. (60, 40, 222),
  23. (128, 128, 0),
  24. (192, 128, 128),
  25. (64, 64, 128),
  26. (64, 0, 128),
  27. (64, 64, 0),
  28. (0, 128, 192),
  29. (0, 0, 0),
  30. ])
  31. def __init__(self,
  32. root,
  33. split='train',
  34. transform=None,
  35. target_transform=None,
  36. transforms=None):
  37. assert split in ('train', 'val', 'test', 'trainval')
  38. super( CamVid, self ).__init__(root=root, transforms=transforms, transform=transform, target_transform=target_transform)
  39. self.root = os.path.expanduser(root)
  40. self.split = split
  41. if split == 'trainval':
  42. self.images = glob(os.path.join(self.root, 'train', '*.png')) + glob(os.path.join(self.root, 'val', '*.png'))
  43. self.labels = glob(os.path.join(self.root, 'trainannot', '*.png')) + glob(os.path.join(self.root, 'valannot', '*.png'))
  44. else:
  45. self.images = glob(os.path.join(self.root, self.split, '*.png'))
  46. self.labels = glob(os.path.join(self.root, self.split+'annot', '*.png'))
  47. self.images.sort()
  48. self.labels.sort()
  49. def __getitem__(self, idx):
  50. """
  51. Args:
  52. - index (``int``): index of the item in the dataset
  53. Returns:
  54. A tuple of ``PIL.Image`` (image, label) where label is the ground-truth
  55. of the image.
  56. """
  57. img, label = Image.open(self.images[idx]), Image.open(self.labels[idx])
  58. if self.transforms is not None:
  59. img, label = self.transforms(img, label)
  60. label[label == 11] = 255 # ignore void
  61. return img, label.squeeze(0)
  62. def __len__(self):
  63. return len(self.images)
  64. @classmethod
  65. def decode_fn(cls, mask):
  66. """decode semantic mask to RGB image"""
  67. mask[mask == 255] = 11
  68. return cls.cmap[mask]

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能

Contributors (1)