You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

stanford_dogs.py 2.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import os
  2. import numpy as np
  3. from PIL import Image
  4. from scipy.io import loadmat
  5. from torch.utils import data
  6. from .utils import download_url
  7. from shutil import move
  8. class StanfordDogs(data.Dataset):
  9. """Dataset for Stanford Dogs
  10. """
  11. urls = {"images.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar",
  12. "annotation.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/annotation.tar",
  13. "lists.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/lists.tar"}
  14. def __init__(self, root, split='train', download=False, transform=None, target_transform=None):
  15. self.root = os.path.abspath( os.path.expanduser(root) )
  16. self.split = split
  17. self.transform = transform
  18. self.target_transform = target_transform
  19. if download:
  20. self.download()
  21. list_file = os.path.join(self.root, self.split+'_list.mat')
  22. mat_file = loadmat(list_file)
  23. size = len(mat_file['file_list'])
  24. self.files = [str(mat_file['file_list'][i][0][0]) for i in range(size)]
  25. self.labels = np.array(
  26. [mat_file['labels'][i][0]-1 for i in range(size)])
  27. categories = os.listdir(os.path.join(self.root, 'Images'))
  28. categories.sort()
  29. self.object_categories = [c[10:] for c in categories]
  30. print('Stanford Dogs, Split: %s, Size: %d' %
  31. (self.split, self.__len__()))
  32. def __len__(self):
  33. return len(self.files)
  34. def __getitem__(self, idx):
  35. img = Image.open(os.path.join(self.root, 'Images',
  36. self.files[idx])).convert("RGB")
  37. lbl = self.labels[idx]
  38. if self.transform is not None:
  39. img = self.transform(img)
  40. if self.target_transform is not None:
  41. lbl = self.target_transform( lbl )
  42. return img, lbl
  43. def download(self):
  44. import tarfile
  45. os.makedirs(self.root, exist_ok=True)
  46. for fname, url in self.urls.items():
  47. if not os.path.isfile(os.path.join(self.root, fname)):
  48. download_url(url, self.root, fname)
  49. # extract file
  50. print("Extracting %s..." % fname)
  51. with tarfile.open(os.path.join(self.root, fname), "r") as tar:
  52. tar.extractall(path=self.root)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能

Contributors (1)