
stanford_cars.py 2.9 kB

import os
import glob
from shutil import copyfile

import numpy as np
from PIL import Image
from scipy.io import loadmat
from torch.utils import data

from .utils import download_url, mkdir


class StanfordCars(data.Dataset):
    """Dataset for Stanford Cars."""

    urls = {'cars_train.tgz': 'http://imagenet.stanford.edu/internal/car196/cars_train.tgz',
            'cars_test.tgz': 'http://imagenet.stanford.edu/internal/car196/cars_test.tgz',
            'car_devkit.tgz': 'https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz',
            'cars_test_annos_withlabels.mat': 'http://imagenet.stanford.edu/internal/car196/cars_test_annos_withlabels.mat'}

    def __init__(self, root, split='train', download=False, transform=None, target_transform=None):
        self.root = os.path.abspath(os.path.expanduser(root))
        self.split = split
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        # The train split ships its annotations in the devkit; the test split
        # uses the separately downloaded annotation file that includes labels.
        if self.split == 'train':
            annos = os.path.join(self.root, 'devkit', 'cars_train_annos.mat')
        else:
            annos = os.path.join(self.root, 'devkit',
                                 'cars_test_annos_withlabels.mat')
        annos = loadmat(annos)

        # Sort the image paths so their order matches the annotation order in the .mat file.
        self.files = glob.glob(os.path.join(
            self.root, 'cars_' + self.split, '*.jpg'))
        self.files.sort()
        # Field 4 of each annotation entry is the 1-indexed class id; shift to 0-indexed.
        self.labels = np.array([int(l[4]) - 1 for l in annos['annotations'][0]])

        # Human-readable class names from the devkit metadata.
        lbl_annos = loadmat(os.path.join(self.root, 'devkit', 'cars_meta.mat'))
        self.object_categories = [str(c[0])
                                  for c in lbl_annos['class_names'][0]]
        print('Stanford Cars, Split: %s, Size: %d' %
              (self.split, self.__len__()))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # self.files already holds absolute paths built from self.root.
        img = Image.open(self.files[idx]).convert("RGB")
        lbl = self.labels[idx]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            lbl = self.target_transform(lbl)
        return img, lbl

    def download(self):
        import tarfile
        mkdir(self.root)
        for fname, url in self.urls.items():
            if not os.path.isfile(os.path.join(self.root, fname)):
                download_url(url, self.root, fname)
                if fname.endswith('tgz'):
                    print("Extracting %s..." % fname)
                    with tarfile.open(os.path.join(self.root, fname), "r:gz") as tar:
                        tar.extractall(path=self.root)
        # The loader above looks for the labeled test annotations under devkit/.
        copyfile(os.path.join(self.root, 'cars_test_annos_withlabels.mat'),
                 os.path.join(self.root, 'devkit', 'cars_test_annos_withlabels.mat'))
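A minimal usage sketch, not part of the file above: it assumes StanfordCars is imported from the package containing this module (so the relative import of download_url and mkdir resolves), and that root points at a directory laid out as devkit/, cars_train/ and cars_test/. The root path, transform pipeline, and batch size are illustrative assumptions.

import torch
from torchvision import transforms

# Assumed: StanfordCars imported from its package, e.g.
#   from mypackage.stanford_cars import StanfordCars
train_set = StanfordCars(
    root='~/data/stanford_cars',   # assumed dataset location
    split='train',
    download=True,                 # uses the sibling utils.download_url / utils.mkdir helpers
    transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]),
)

loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
images, labels = next(iter(loader))
print(images.shape, labels.shape)  # e.g. torch.Size([32, 3, 224, 224]) and torch.Size([32])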

