You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

imagenet.py 8.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. # -*- coding: utf-8 -*-
  2. # BSD 3-Clause License
  3. #
  4. # Copyright (c) Soumith Chintala 2016,
  5. # All rights reserved.
  6. # ---------------------------------------------------------------------
  7. #
  8. # This file has been modified by Megvii ("Megvii Modifications").
  9. # All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved.
  10. # ---------------------------------------------------------------------
  11. import os
  12. import shutil
  13. from tqdm import tqdm
  14. from ....distributed.group import is_distributed
  15. from ....logger import get_logger
  16. from ....serialization import load, save
  17. from .folder import ImageFolder
  18. from .utils import _default_dataset_root, calculate_md5, untar, untargz
# Module-level logger used by the ImageNet dataset class below.
logger = get_logger(__name__)
  20. class ImageNet(ImageFolder):
  21. r"""Load ImageNet from raw files or folder. Expected folder looks like:
  22. .. code-block:: shell
  23. ${root}/
  24. | [REQUIRED TAR FILES]
  25. |- ILSVRC2012_img_train.tar
  26. |- ILSVRC2012_img_val.tar
  27. |- ILSVRC2012_devkit_t12.tar.gz
  28. | [OPTIONAL IMAGE FOLDERS]
  29. |- train/cls/xxx.${img_ext}
  30. |- val/cls/xxx.${img_ext}
  31. |- ILSVRC2012_devkit_t12/data/meta.mat
  32. |- ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt
  33. If the image folders don't exist, raw tar files are required to get extracted and processed.
  34. * if ``root`` contains ``self.target_folder`` depending on ``train``:
  35. * initialize ImageFolder with target_folder.
  36. * else:
  37. * if all raw files are in ``root``:
  38. * parse ``self.target_folder`` from raw files.
  39. * initialize ImageFolder with ``self.target_folder``.
  40. * else:
  41. * raise error.
  42. Args:
  43. root: root directory of imagenet data, if root is ``None``, use default_dataset_root.
  44. train: if ``True``, load the train split, otherwise load the validation split.
  45. """
  46. raw_file_meta = {
  47. "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
  48. "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
  49. "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
  50. } # ImageNet raw files
  51. default_train_dir = "train"
  52. default_val_dir = "val"
  53. default_devkit_dir = "ILSVRC2012_devkit_t12"
  54. def __init__(self, root: str = None, train: bool = True, **kwargs):
  55. # process the root path
  56. if root is None:
  57. self.root = self._default_root
  58. else:
  59. self.root = root
  60. if not os.path.exists(self.root):
  61. raise FileNotFoundError("dir %s does not exist" % self.root)
  62. self.devkit_dir = os.path.join(self.root, self.default_devkit_dir)
  63. if not os.path.exists(self.devkit_dir):
  64. logger.warning("devkit directory %s does not exists", self.devkit_dir)
  65. self._prepare_devkit()
  66. self.train = train
  67. if train:
  68. self.target_folder = os.path.join(self.root, self.default_train_dir)
  69. else:
  70. self.target_folder = os.path.join(self.root, self.default_val_dir)
  71. if not os.path.exists(self.target_folder):
  72. logger.warning(
  73. "expected image folder %s does not exist, try to load from raw file",
  74. self.target_folder,
  75. )
  76. if not self.check_raw_file():
  77. raise FileNotFoundError(
  78. "expected image folder %s does not exist, and raw files do not exist in %s"
  79. % (self.target_folder, self.root)
  80. )
  81. elif is_distributed():
  82. raise RuntimeError(
  83. "extracting raw file shouldn't be done in distributed mode, use single process instead"
  84. )
  85. elif train:
  86. self._prepare_train()
  87. else:
  88. self._prepare_val()
  89. super().__init__(self.target_folder, **kwargs)
  90. @property
  91. def _default_root(self):
  92. return os.path.join(_default_dataset_root(), self.__class__.__name__)
  93. @property
  94. def valid_ground_truth(self):
  95. groud_truth_path = os.path.join(
  96. self.devkit_dir, "data", "ILSVRC2012_validation_ground_truth.txt"
  97. )
  98. if os.path.exists(groud_truth_path):
  99. with open(groud_truth_path, "r") as f:
  100. val_labels = f.readlines()
  101. return [int(val_label) for val_label in val_labels]
  102. else:
  103. raise FileNotFoundError(
  104. "valid ground truth file %s does not exist" % groud_truth_path
  105. )
  106. @property
  107. def meta(self):
  108. try:
  109. return load(os.path.join(self.devkit_dir, "meta.pkl"))
  110. except FileNotFoundError:
  111. import scipy.io
  112. meta_path = os.path.join(self.devkit_dir, "data", "meta.mat")
  113. if not os.path.exists(meta_path):
  114. raise FileNotFoundError("meta file %s does not exist" % meta_path)
  115. meta = scipy.io.loadmat(meta_path, squeeze_me=True)["synsets"]
  116. nums_children = list(zip(*meta))[4]
  117. meta = [
  118. meta[idx]
  119. for idx, num_children in enumerate(nums_children)
  120. if num_children == 0
  121. ]
  122. idcs, wnids, classes = list(zip(*meta))[:3]
  123. classes = [tuple(clss.split(", ")) for clss in classes]
  124. idx_to_wnid = dict(zip(idcs, wnids))
  125. wnid_to_classes = dict(zip(wnids, classes))
  126. logger.info(
  127. "saving cached meta file to %s",
  128. os.path.join(self.devkit_dir, "meta.pkl"),
  129. )
  130. save(
  131. (idx_to_wnid, wnid_to_classes),
  132. os.path.join(self.devkit_dir, "meta.pkl"),
  133. )
  134. return idx_to_wnid, wnid_to_classes
  135. def check_raw_file(self) -> bool:
  136. return all(
  137. [
  138. os.path.exists(os.path.join(self.root, value[0]))
  139. for _, value in self.raw_file_meta.items()
  140. ]
  141. )
  142. def _organize_val_data(self):
  143. id2wnid = self.meta[0]
  144. val_idcs = self.valid_ground_truth
  145. val_wnids = [id2wnid[idx] for idx in val_idcs]
  146. val_images = sorted(
  147. [
  148. os.path.join(self.target_folder, image)
  149. for image in os.listdir(self.target_folder)
  150. ]
  151. )
  152. logger.debug("mkdir for val set wnids")
  153. for wnid in set(val_wnids):
  154. os.makedirs(os.path.join(self.root, self.default_val_dir, wnid))
  155. logger.debug("mv val images into wnids dir")
  156. for wnid, img_file in tqdm(zip(val_wnids, val_images)):
  157. shutil.move(
  158. img_file,
  159. os.path.join(
  160. self.root, self.default_val_dir, wnid, os.path.basename(img_file)
  161. ),
  162. )
  163. def _prepare_val(self):
  164. assert not self.train
  165. raw_filename, checksum = self.raw_file_meta["val"]
  166. raw_file = os.path.join(self.root, raw_filename)
  167. logger.info("checksum valid tar file %s ...", raw_file)
  168. assert (
  169. calculate_md5(raw_file) == checksum
  170. ), "checksum mismatch, {} may be damaged".format(raw_file)
  171. logger.info("extract valid tar file... this may take 10-20 minutes")
  172. untar(raw_file, self.target_folder)
  173. self._organize_val_data()
  174. def _prepare_train(self):
  175. assert self.train
  176. raw_filename, checksum = self.raw_file_meta["train"]
  177. raw_file = os.path.join(self.root, raw_filename)
  178. logger.info("checksum train tar file %s ...", raw_file)
  179. assert (
  180. calculate_md5(raw_file) == checksum
  181. ), "checksum mismatch, {} may be damaged".format(raw_file)
  182. logger.info("extract train tar file.. this may take several hours")
  183. untar(raw_file, self.target_folder)
  184. paths = [
  185. os.path.join(self.target_folder, child_dir)
  186. for child_dir in os.listdir(self.target_folder)
  187. ]
  188. for path in tqdm(paths):
  189. untar(path, os.path.splitext(path)[0], remove=True)
  190. def _prepare_devkit(self):
  191. raw_filename, checksum = self.raw_file_meta["devkit"]
  192. raw_file = os.path.join(self.root, raw_filename)
  193. logger.info("checksum devkit tar file %s ...", raw_file)
  194. assert (
  195. calculate_md5(raw_file) == checksum
  196. ), "checksum mismatch, {} may be damaged".format(raw_file)
  197. logger.info("extract devkit file..")
  198. untargz(os.path.join(self.root, self.raw_file_meta["devkit"][0]))