You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

cifar.py 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # -*- coding: utf-8 -*-
  2. import os
  3. import pickle
  4. import tarfile
  5. from typing import Tuple
  6. import numpy as np
  7. from ....logger import get_logger
  8. from .meta_vision import VisionDataset
  9. from .utils import _default_dataset_root, load_raw_data_from_url
  10. logger = get_logger(__name__)
  11. class CIFAR10(VisionDataset):
  12. r""":class:`~.Dataset` for CIFAR10 meta data."""
  13. url_path = "http://www.cs.utoronto.ca/~kriz/"
  14. raw_file_name = "cifar-10-python.tar.gz"
  15. raw_file_md5 = "c58f30108f718f92721af3b95e74349a"
  16. raw_file_dir = "cifar-10-batches-py"
  17. train_batch = [
  18. "data_batch_1",
  19. "data_batch_2",
  20. "data_batch_3",
  21. "data_batch_4",
  22. "data_batch_5",
  23. ]
  24. test_batch = ["test_batch"]
  25. meta_info = {"name": "batches.meta"}
  26. def __init__(
  27. self,
  28. root: str = None,
  29. train: bool = True,
  30. download: bool = True,
  31. timeout: int = 500,
  32. ):
  33. super().__init__(root, order=("image", "image_category"))
  34. self.timeout = timeout
  35. # process the root path
  36. if root is None:
  37. self.root = self._default_root
  38. if not os.path.exists(self.root):
  39. os.makedirs(self.root)
  40. else:
  41. self.root = root
  42. if not os.path.exists(self.root):
  43. if download:
  44. logger.debug(
  45. "dir %s does not exist, will be automatically created",
  46. self.root,
  47. )
  48. os.makedirs(self.root)
  49. else:
  50. raise ValueError("dir %s does not exist" % self.root)
  51. self.target_file = os.path.join(self.root, self.raw_file_dir)
  52. # check existence of target pickle dir, if exists load the
  53. # pickle file no matter what download is set
  54. if os.path.exists(self.target_file):
  55. if train:
  56. self.arrays = self.bytes2array(self.train_batch)
  57. else:
  58. self.arrays = self.bytes2array(self.test_batch)
  59. else:
  60. if download:
  61. self.download()
  62. if train:
  63. self.arrays = self.bytes2array(self.train_batch)
  64. else:
  65. self.arrays = self.bytes2array(self.test_batch)
  66. else:
  67. raise ValueError(
  68. "dir does not contain target file %s, please set download=True"
  69. % (self.target_file)
  70. )
  71. def __getitem__(self, index: int) -> Tuple:
  72. return tuple(array[index] for array in self.arrays)
  73. def __len__(self) -> int:
  74. return len(self.arrays[0])
  75. @property
  76. def _default_root(self):
  77. return os.path.join(_default_dataset_root(), self.__class__.__name__)
  78. @property
  79. def meta(self):
  80. meta_path = os.path.join(self.root, self.raw_file_dir, self.meta_info["name"])
  81. with open(meta_path, "rb") as f:
  82. meta = pickle.load(f, encoding="bytes")
  83. return meta
  84. def download(self):
  85. url = self.url_path + self.raw_file_name
  86. load_raw_data_from_url(url, self.raw_file_name, self.raw_file_md5, self.root)
  87. self.process()
  88. def untar(self, file_path, dirs):
  89. assert file_path.endswith(".tar.gz")
  90. logger.debug("untar file %s to %s", file_path, dirs)
  91. t = tarfile.open(file_path)
  92. t.extractall(path=dirs)
  93. def bytes2array(self, filenames):
  94. data = []
  95. label = []
  96. for filename in filenames:
  97. path = os.path.join(self.root, self.raw_file_dir, filename)
  98. logger.debug("unpickle file %s", path)
  99. with open(path, "rb") as fo:
  100. dic = pickle.load(fo, encoding="bytes")
  101. batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
  102. data.extend(list(batch_data[..., [2, 1, 0]]))
  103. label.extend(dic[b"labels"])
  104. label = np.array(label, dtype=np.int32)
  105. return (data, label)
  106. def process(self):
  107. logger.info("process raw data ...")
  108. self.untar(os.path.join(self.root, self.raw_file_name), self.root)
  109. class CIFAR100(CIFAR10):
  110. r""":class:`~.Dataset` for CIFAR100 meta data."""
  111. url_path = "http://www.cs.utoronto.ca/~kriz/"
  112. raw_file_name = "cifar-100-python.tar.gz"
  113. raw_file_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
  114. raw_file_dir = "cifar-100-python"
  115. train_batch = ["train"]
  116. test_batch = ["test"]
  117. meta_info = {"name": "meta"}
  118. @property
  119. def meta(self):
  120. meta_path = os.path.join(self.root, self.raw_file_dir, self.meta_info["name"])
  121. with open(meta_path, "rb") as f:
  122. meta = pickle.load(f, encoding="bytes")
  123. return meta
  124. def bytes2array(self, filenames):
  125. data = []
  126. fine_label = []
  127. coarse_label = []
  128. for filename in filenames:
  129. path = os.path.join(self.root, self.raw_file_dir, filename)
  130. logger.debug("unpickle file %s", path)
  131. with open(path, "rb") as fo:
  132. dic = pickle.load(fo, encoding="bytes")
  133. batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
  134. data.extend(list(batch_data[..., [2, 1, 0]]))
  135. fine_label.extend(dic[b"fine_labels"])
  136. coarse_label.extend(dic[b"coarse_labels"])
  137. fine_label = np.array(fine_label, dtype=np.int32)
  138. coarse_label = np.array(coarse_label, dtype=np.int32)
  139. return data, fine_label, coarse_label