You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

imagenet.py 8.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. # -*- coding: utf-8 -*-
  2. # BSD 3-Clause License
  3. #
  4. # Copyright (c) Soumith Chintala 2016,
  5. # All rights reserved.
  6. # ---------------------------------------------------------------------
  7. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  8. #
  9. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. #
  15. # This file has been modified by Megvii ("Megvii Modifications").
  16. # All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved.
  17. # ---------------------------------------------------------------------
  18. import os
  19. import shutil
  20. from tqdm import tqdm
  21. from ....core.serialization import load, save
  22. from ....distributed.util import is_distributed
  23. from ....logger import get_logger
  24. from .folder import ImageFolder
  25. from .utils import _default_dataset_root, untar, untargz
  26. logger = get_logger(__name__)
  27. class ImageNet(ImageFolder):
  28. r"""
  29. Load ImageNet from raw files or folder, expected folder looks like
  30. raw files situation (optional):
  31. root/ILSVRC2012_img_train.tar
  32. root/ILSVRC2012_img_val.tar
  33. root/ILSVRC2012_devkit_t12.tar.gz
  34. image folder situation (required):
  35. root/train/cls/xxx.${img_ext}
  36. root/val/cls/xxx.${img_ext}
  37. root/ILSVRC2012_devkit_t12/data/meta.mat
  38. root/ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt
  39. If the required folders don't exist, raw files are required to get extracted and processed.
  40. """
  41. raw_file_meta = {
  42. "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
  43. "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
  44. "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
  45. }
  46. """
  47. raw files of ImageNet (train, val, devkit)
  48. """
  49. default_train_dir = "train"
  50. """
  51. directory of train data
  52. """
  53. default_val_dir = "val"
  54. """
  55. directory of val data
  56. """
  57. default_devkit_dir = "ILSVRC2012_devkit_t12"
  58. """
  59. directory of devkit
  60. """
  61. def __init__(self, root: str = None, train: bool = True, **kwargs):
  62. r"""
  63. initialization:
  64. * if ``root`` contains ``self.target_folder`` depent on ``train``:
  65. * initialize ImageFolder with target_folder
  66. * else:
  67. * if all raw files are in ``root``:
  68. * parse ``self.target_folder`` from raw files
  69. * initialize ImageFolder with ``self.target_folder``
  70. * else:
  71. * raise error
  72. :param root: root directory of imagenet data, if root is ``None``, used default_dataset_root
  73. :param train: if ``True``, load the train split, otherwise load the validation split
  74. """
  75. # process the root path
  76. if root is None:
  77. self.root = self._default_root
  78. else:
  79. self.root = root
  80. self.devkit_dir = os.path.join(self.root, self.default_devkit_dir)
  81. if not os.path.exists(self.root):
  82. raise FileNotFoundError("dir %s does not exist" % self.root)
  83. if not os.path.exists(self.devkit_dir):
  84. logger.warning("devkit directory %s does not exists" % self.devkit_dir)
  85. if train:
  86. self.target_folder = os.path.join(self.root, self.default_train_dir)
  87. else:
  88. self.target_folder = os.path.join(self.root, self.default_val_dir)
  89. if not os.path.exists(self.target_folder):
  90. logger.warning(
  91. "expected image folder %s does not exist, try to load from raw file"
  92. % self.target_folder
  93. )
  94. if not self.check_raw_file():
  95. raise FileNotFoundError(
  96. "expected image folder %s does not exist, and raw files do not exist in %s"
  97. % (self.target_folder, self.root)
  98. )
  99. elif is_distributed():
  100. raise RuntimeError(
  101. "extracting raw file shouldn't be done in distributed mode, use single process instead"
  102. )
  103. else:
  104. self.parse(train)
  105. super().__init__(self.target_folder, **kwargs)
  106. @property
  107. def _default_root(self):
  108. return os.path.join(_default_dataset_root(), self.__class__.__name__)
  109. @property
  110. def valid_ground_truth(self):
  111. groud_truth_path = os.path.join(
  112. self.devkit_dir, "data", "ILSVRC2012_validation_ground_truth.txt"
  113. )
  114. if os.path.exists(groud_truth_path):
  115. with open(groud_truth_path, "r") as f:
  116. val_labels = f.readlines()
  117. return [int(val_label) for val_label in val_labels]
  118. else:
  119. raise FileNotFoundError(
  120. "valid ground truth file %s does not exist" % groud_truth_path
  121. )
  122. @property
  123. def meta(self):
  124. try:
  125. return load(os.path.join(self.devkit_dir, "meta.pkl"))
  126. except FileNotFoundError:
  127. import scipy.io as sio
  128. meta_path = os.path.join(self.devkit_dir, "data", "meta.mat")
  129. if not os.path.exists(meta_path):
  130. raise FileNotFoundError("meta file %s does not exist" % meta_path)
  131. meta = sio.loadmat(meta_path, squeeze_me=True)["synsets"]
  132. nums_children = list(zip(*meta))[4]
  133. meta = [
  134. meta[idx]
  135. for idx, num_children in enumerate(nums_children)
  136. if num_children == 0
  137. ]
  138. idcs, wnids, classes = list(zip(*meta))[:3]
  139. classes = [tuple(clss.split(", ")) for clss in classes]
  140. idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
  141. wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
  142. logger.info("saving cached meta file to %s", os.path.join(self.devkit_dir, "meta.pkl"))
  143. save((idx_to_wnid, wnid_to_classes), os.path.join(self.devkit_dir, "meta.pkl"))
  144. return idx_to_wnid, wnid_to_classes
  145. def check_raw_file(self) -> bool:
  146. return all(
  147. [
  148. os.path.exists(os.path.join(self.root, value[0]))
  149. for _, value in self.raw_file_meta.items()
  150. ]
  151. )
  152. def organize_val_data(self):
  153. id2wnid = self.meta[0]
  154. val_idcs = self.valid_ground_truth
  155. val_wnids = [id2wnid[idx] for idx in val_idcs]
  156. raw_val_dir = os.path.join(self.root, "ILSVRC2012_img_val")
  157. val_images = sorted(
  158. [os.path.join(raw_val_dir, image) for image in os.listdir(raw_val_dir)]
  159. )
  160. logger.debug("mkdir for val set wnids")
  161. for wnid in set(val_wnids):
  162. os.makedirs(os.path.join(self.root, self.default_val_dir, wnid))
  163. logger.debug("mv val images into wnids dir")
  164. for wnid, img_file in tqdm(zip(val_wnids, val_images)):
  165. shutil.move(
  166. img_file,
  167. os.path.join(
  168. self.root, self.default_val_dir, wnid, os.path.basename(img_file)
  169. ),
  170. )
  171. def parse(self, train):
  172. if train:
  173. logger.info("process train raw file.. this may take several hours")
  174. untar(
  175. os.path.join(self.root, self.raw_file_meta["train"][0]),
  176. self.target_folder,
  177. )
  178. paths = [
  179. os.path.join(self.target_folder, child_dir)
  180. for child_dir in os.listdir(self.target_folder)
  181. ]
  182. for path in tqdm(paths):
  183. untar(path, os.path.splitext(path)[0], remove=True)
  184. else:
  185. logger.info("process devkit file..")
  186. untargz(os.path.join(self.root, self.raw_file_meta["devkit"][0]))
  187. logger.info("process valid raw file.. this may take 10-20 minutes")
  188. untar(os.path.join(self.root, self.raw_file_meta["val"][0]))
  189. self.organize_val_data()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。

Contributors (1)