You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

fgvc_aircraft.py 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. #Modified from https://github.com/pytorch/vision/pull/467/files
  2. from __future__ import print_function
  3. import torch.utils.data as data
  4. from torchvision.datasets.folder import pil_loader, accimage_loader, default_loader
  5. from PIL import Image
  6. import os
  7. import numpy as np
  8. from .utils import download_url, mkdir
  9. def make_dataset(dir, image_ids, targets):
  10. assert(len(image_ids) == len(targets))
  11. images = []
  12. dir = os.path.expanduser(dir)
  13. for i in range(len(image_ids)):
  14. item = (os.path.join(dir, 'fgvc-aircraft-2013b', 'data', 'images',
  15. '%s.jpg' % image_ids[i]), targets[i])
  16. images.append(item)
  17. return images
  18. def find_classes(classes_file):
  19. # read classes file, separating out image IDs and class names
  20. image_ids = []
  21. targets = []
  22. f = open(classes_file, 'r')
  23. for line in f:
  24. split_line = line.split(' ')
  25. image_ids.append(split_line[0])
  26. targets.append(' '.join(split_line[1:]))
  27. f.close()
  28. # index class names
  29. classes = np.unique(targets)
  30. class_to_idx = {classes[i]: i for i in range(len(classes))}
  31. targets = [class_to_idx[c] for c in targets]
  32. return (image_ids, targets, classes, class_to_idx)
  33. class FGVCAircraft(data.Dataset):
  34. """`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset.
  35. Args:
  36. root (string): Root directory path to dataset.
  37. class_type (string, optional): The level of FGVC-Aircraft fine-grain classification
  38. to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
  39. transforms (callable, optional): A function/transforms that takes in a PIL image
  40. and returns a transformed version. E.g. ``transforms.RandomCrop``
  41. target_transform (callable, optional): A function/transforms that takes in the
  42. target and transforms it.
  43. loader (callable, optional): A function to load an image given its path.
  44. download (bool, optional): If true, downloads the dataset from the internet and
  45. puts it in the root directory. If dataset is already downloaded, it is not
  46. downloaded again.
  47. """
  48. url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
  49. class_types = ('variant', 'family', 'manufacturer')
  50. splits = ('train', 'val', 'trainval', 'test')
  51. def __init__(self, root, class_type='variant', split='train', transform=None,
  52. target_transform=None, loader=default_loader, download=False):
  53. if split not in self.splits:
  54. raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
  55. split, ', '.join(self.splits),
  56. ))
  57. if class_type not in self.class_types:
  58. raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
  59. class_type, ', '.join(self.class_types),
  60. ))
  61. self.root = root
  62. self.class_type = class_type
  63. self.split = split
  64. self.classes_file = os.path.join(self.root, 'fgvc-aircraft-2013b', 'data',
  65. 'images_%s_%s.txt' % (self.class_type, self.split))
  66. if download:
  67. self.download()
  68. (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)
  69. samples = make_dataset(self.root, image_ids, targets)
  70. self.transform = transform
  71. self.target_transform = target_transform
  72. self.loader = loader
  73. self.samples = samples
  74. self.classes = classes
  75. self.class_to_idx = class_to_idx
  76. with open(os.path.join(self.root, 'fgvc-aircraft-2013b/data', 'variants.txt')) as f:
  77. self.object_categories = [
  78. line.strip('\n') for line in f.readlines()]
  79. print('FGVC-Aircraft, Split: %s, Size: %d' % (self.split, self.__len__()))
  80. def __getitem__(self, index):
  81. """
  82. Args:
  83. index (int): Index
  84. Returns:
  85. tuple: (sample, target) where target is class_index of the target class.
  86. """
  87. path, target = self.samples[index]
  88. sample = self.loader(path)
  89. if self.transform is not None:
  90. sample = self.transform(sample)
  91. if self.target_transform is not None:
  92. target = self.target_transform(target)
  93. return sample, target
  94. def __len__(self):
  95. return len(self.samples)
  96. def __repr__(self):
  97. fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
  98. fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
  99. fmt_str += ' Root Location: {}\n'.format(self.root)
  100. tmp = ' Transforms (if any): '
  101. fmt_str += '{0}{1}\n'.format(
  102. tmp, self.transforms.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
  103. tmp = ' Target Transforms (if any): '
  104. fmt_str += '{0}{1}'.format(
  105. tmp, self.target_transforms.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
  106. return fmt_str
  107. def _check_exists(self):
  108. return os.path.exists(os.path.join(self.root, 'data', 'images')) and \
  109. os.path.exists(self.classes_file)
  110. def download(self):
  111. """Download the FGVC-Aircraft data if it doesn't exist already."""
  112. from six.moves import urllib
  113. import tarfile
  114. mkdir(self.root)
  115. fpath = os.path.join(self.root, 'fgvc-aircraft-2013b.tar.gz')
  116. if not os.path.isfile(fpath):
  117. download_url(self.url, self.root, 'fgvc-aircraft-2013b.tar.gz')
  118. print("Extracting fgvc-aircraft-2013b.tar.gz...")
  119. with tarfile.open(fpath, "r:gz") as tar:
  120. tar.extractall(path=self.root)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能

Contributors (1)