You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

nyu.py 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Modified from https://github.com/VainF/nyuv2-python-toolkit
  2. import os
  3. import torch
  4. import torch.utils.data as data
  5. from PIL import Image
  6. from scipy.io import loadmat
  7. import numpy as np
  8. import glob
  9. from torchvision import transforms
  10. from torchvision.datasets import VisionDataset
  11. import random
  12. from .utils import colormap
  13. class NYUv2(VisionDataset):
  14. """NYUv2 dataset
  15. See https://github.com/VainF/nyuv2-python-toolkit for more details.
  16. Args:
  17. root (string): Root directory path.
  18. split (string, optional): 'train' for training set, and 'test' for test set. Default: 'train'.
  19. target_type (string, optional): Type of target to use, ``semantic``, ``depth`` or ``normal``.
  20. num_classes (int, optional): The number of classes, must be 40 or 13. Default:13.
  21. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version.
  22. target_transform (callable, optional): A function/transform that takes in the target and transforms it.
  23. transforms (callable, optional): A function/transform that takes input sample and its target as entry and returns a transformed version.
  24. """
  25. cmap = colormap()
  26. def __init__(self,
  27. root,
  28. split='train',
  29. target_type='semantic',
  30. num_classes=13,
  31. transforms=None,
  32. transform=None,
  33. target_transform=None):
  34. super( NYUv2, self ).__init__(root, transforms=transforms, transform=transform, target_transform=target_transform)
  35. assert(split in ('train', 'test'))
  36. self.root = root
  37. self.split = split
  38. self.target_type = target_type
  39. self.num_classes = num_classes
  40. split_mat = loadmat(os.path.join(self.root, 'splits.mat'))
  41. idxs = split_mat[self.split+'Ndxs'].reshape(-1) - 1
  42. img_names = os.listdir( os.path.join(self.root, 'image', self.split) )
  43. img_names.sort()
  44. images_dir = os.path.join(self.root, 'image', self.split)
  45. self.images = [os.path.join(images_dir, name) for name in img_names]
  46. self._is_depth = False
  47. if self.target_type=='semantic':
  48. semantic_dir = os.path.join(self.root, 'seg%d'%self.num_classes, self.split)
  49. self.labels = [os.path.join(semantic_dir, name) for name in img_names]
  50. self.targets = self.labels
  51. if self.target_type=='depth':
  52. depth_dir = os.path.join(self.root, 'depth', self.split)
  53. self.depths = [os.path.join(depth_dir, name) for name in img_names]
  54. self.targets = self.depths
  55. self._is_depth = True
  56. if self.target_type=='normal':
  57. normal_dir = os.path.join(self.root, 'normal', self.split)
  58. self.normals = [os.path.join(normal_dir, name) for name in img_names]
  59. self.targets = self.normals
  60. def __getitem__(self, idx):
  61. image = Image.open(self.images[idx])
  62. target = Image.open(self.targets[idx])
  63. if self.transforms is not None:
  64. image, target = self.transforms( image, target )
  65. return image, target
  66. def __len__(self):
  67. return len(self.images)
  68. @classmethod
  69. def decode_fn(cls, mask: np.ndarray):
  70. """decode semantic mask to RGB image"""
  71. mask = mask.astype('uint8') + 1 # 255 => 0
  72. return cls.cmap[mask]

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已为产学研等各领域近千家单位及个人提供AI应用赋能。

Contributors (1)