
cub200.py

# Modified from https://github.com/TDeVries/cub2011_dataset/blob/master/cub2011.py
import os

import pandas as pd
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader

from .utils import download_url


class CUB200(Dataset):
    """Caltech-UCSD Birds-200-2011 (CUB-200-2011) classification dataset."""

    base_folder = 'images'
    url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
    filename = 'CUB_200_2011.tgz'
    tgz_md5 = '97eceeb196236b17998738112f37df78'

    def __init__(self, root, split='train', transform=None, target_transform=None,
                 loader=default_loader, download=False):
        self.root = os.path.abspath(os.path.expanduser(root))
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.split = split

        if download:
            self.download()

        self._load_metadata()

        # Class names: drop the leading numeric prefix (e.g. '001.') from each folder name.
        categories = os.listdir(os.path.join(self.root, 'CUB_200_2011', 'images'))
        categories.sort()
        self.object_categories = [c[4:] for c in categories]
        print('CUB200, Split: %s, Size: %d' % (self.split, self.__len__()))

    def _load_metadata(self):
        # Join image paths, class labels, and the official train/test split on img_id.
        images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'),
                             sep=' ', names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'],
                                         encoding='latin-1', engine='python')
        train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'],
                                       encoding='latin-1', engine='python')
        data = images.merge(image_class_labels, on='img_id')
        self.data = data.merge(train_test_split, on='img_id')

        if self.split == 'train':
            self.data = self.data[self.data.is_training_img == 1]
        else:
            self.data = self.data[self.data.is_training_img == 0]

    def download(self):
        import tarfile

        os.makedirs(self.root, exist_ok=True)
        if not os.path.isfile(os.path.join(self.root, self.filename)):
            download_url(self.url, self.root, self.filename)
        print("Extracting %s..." % self.filename)
        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
            tar.extractall(path=self.root)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, 'CUB_200_2011', self.base_folder, sample.filepath)
        lbl = sample.target - 1  # labels in the metadata are 1-indexed; convert to 0-indexed
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            lbl = self.target_transform(lbl)
        return img, lbl
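A minimal usage sketch, not part of the file above: it assumes the class is importable as datasets.cub200.CUB200 (a hypothetical package path whose utils module provides download_url), that torchvision is installed, and that the archive can be downloaded into ./data. It builds the train and test splits with standard ImageNet-style transforms and reads one batch through a DataLoader.

from torch.utils.data import DataLoader
from torchvision import transforms

from datasets.cub200 import CUB200  # hypothetical package path for the class above

train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# download=True fetches and extracts CUB_200_2011.tgz into ./data on first use.
train_set = CUB200(root='./data', split='train', transform=train_transform, download=True)
test_set = CUB200(root='./data', split='test', transform=test_transform)

train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # e.g. torch.Size([32, 3, 224, 224]) torch.Size([32])

Labels returned by __getitem__ are 0-indexed (0..199), so they can be fed directly to losses such as cross-entropy without further remapping.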

A suite of platforms and tools, including a one-stop algorithm development platform, a high-performance distributed deep learning framework, a library of advanced algorithm models, a visual model knowledge-amalgamation platform, and a data visualization and analysis platform. It has built distinctive strengths in efficient distributed model training, data processing and visual analysis, and model knowledge amalgamation and lightweighting, and currently provides AI application enablement to nearly a thousand organizations and individuals across industry, academia, and research.
