
imagenet.py 8.9 kB

# -*- coding: utf-8 -*-
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
# ---------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved.
# ---------------------------------------------------------------------
import os
import shutil

from tqdm import tqdm

from ....core.serialization import load, save
from ....distributed.util import is_distributed
from ....logger import get_logger
from .folder import ImageFolder
from .utils import _default_dataset_root, calculate_md5, untar, untargz

logger = get_logger(__name__)


class ImageNet(ImageFolder):
    r"""
    Load ImageNet from raw files or a prepared folder. The expected layout of ``root`` is:

        ${root}/
        |       [REQUIRED TAR FILES]
        |-  ILSVRC2012_img_train.tar
        |-  ILSVRC2012_img_val.tar
        |-  ILSVRC2012_devkit_t12.tar.gz
        |       [OPTIONAL IMAGE FOLDERS]
        |-  train/cls/xxx.${img_ext}
        |-  val/cls/xxx.${img_ext}
        |-  ILSVRC2012_devkit_t12/data/meta.mat
        |-  ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt

    If the image folders do not exist, the raw tar files are required so they can be
    extracted and processed.
    """

    raw_file_meta = {
        "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
        "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
        "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
    }  # ImageNet raw files and their md5 checksums
    default_train_dir = "train"
    default_val_dir = "val"
    default_devkit_dir = "ILSVRC2012_devkit_t12"

    def __init__(self, root: str = None, train: bool = True, **kwargs):
        r"""
        Initialization:

        * if ``root`` contains ``self.target_folder`` (depending on ``train``):

          * initialize ImageFolder with ``self.target_folder``

        * else:

          * if all raw files are in ``root``:

            * parse ``self.target_folder`` from the raw files
            * initialize ImageFolder with ``self.target_folder``

          * else:

            * raise an error

        :param root: root directory of the ImageNet data; if ``root`` is ``None``, the default dataset root is used
        :param train: if ``True``, load the train split, otherwise load the validation split
        """
        # process the root path
        if root is None:
            self.root = self._default_root
        else:
            self.root = root

        if not os.path.exists(self.root):
            raise FileNotFoundError("dir %s does not exist" % self.root)

        self.devkit_dir = os.path.join(self.root, self.default_devkit_dir)

        if not os.path.exists(self.devkit_dir):
            logger.warning("devkit directory %s does not exist" % self.devkit_dir)
            self._prepare_devkit()

        self.train = train

        if train:
            self.target_folder = os.path.join(self.root, self.default_train_dir)
        else:
            self.target_folder = os.path.join(self.root, self.default_val_dir)

        if not os.path.exists(self.target_folder):
            logger.warning(
                "expected image folder %s does not exist, try to load from raw file"
                % self.target_folder
            )
            if not self.check_raw_file():
                raise FileNotFoundError(
                    "expected image folder %s does not exist, and raw files do not exist in %s"
                    % (self.target_folder, self.root)
                )
            elif is_distributed():
                raise RuntimeError(
                    "extracting raw file shouldn't be done in distributed mode, use single process instead"
                )
            elif train:
                self._prepare_train()
            else:
                self._prepare_val()

        super().__init__(self.target_folder, **kwargs)

    @property
    def _default_root(self):
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def valid_ground_truth(self):
        ground_truth_path = os.path.join(
            self.devkit_dir, "data", "ILSVRC2012_validation_ground_truth.txt"
        )
        if os.path.exists(ground_truth_path):
            with open(ground_truth_path, "r") as f:
                val_labels = f.readlines()
            return [int(val_label) for val_label in val_labels]
        else:
            raise FileNotFoundError(
                "valid ground truth file %s does not exist" % ground_truth_path
            )

    @property
    def meta(self):
        try:
            return load(os.path.join(self.devkit_dir, "meta.pkl"))
        except FileNotFoundError:
            import scipy.io as sio

            meta_path = os.path.join(self.devkit_dir, "data", "meta.mat")
            if not os.path.exists(meta_path):
                raise FileNotFoundError("meta file %s does not exist" % meta_path)
            meta = sio.loadmat(meta_path, squeeze_me=True)["synsets"]
            nums_children = list(zip(*meta))[4]
            # keep only leaf synsets (entries with no children)
            meta = [
                meta[idx]
                for idx, num_children in enumerate(nums_children)
                if num_children == 0
            ]
            idcs, wnids, classes = list(zip(*meta))[:3]
            classes = [tuple(clss.split(", ")) for clss in classes]
            idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
            wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
            logger.info(
                "saving cached meta file to %s",
                os.path.join(self.devkit_dir, "meta.pkl"),
            )
            save(
                (idx_to_wnid, wnid_to_classes),
                os.path.join(self.devkit_dir, "meta.pkl"),
            )
            return idx_to_wnid, wnid_to_classes

    def check_raw_file(self) -> bool:
        return all(
            [
                os.path.exists(os.path.join(self.root, value[0]))
                for _, value in self.raw_file_meta.items()
            ]
        )

    def _organize_val_data(self):
        id2wnid = self.meta[0]
        val_idcs = self.valid_ground_truth
        val_wnids = [id2wnid[idx] for idx in val_idcs]

        val_images = sorted(
            [
                os.path.join(self.target_folder, image)
                for image in os.listdir(self.target_folder)
            ]
        )

        logger.debug("mkdir for val set wnids")
        for wnid in set(val_wnids):
            os.makedirs(os.path.join(self.root, self.default_val_dir, wnid))

        logger.debug("mv val images into wnids dir")
        for wnid, img_file in tqdm(zip(val_wnids, val_images)):
            shutil.move(
                img_file,
                os.path.join(
                    self.root, self.default_val_dir, wnid, os.path.basename(img_file)
                ),
            )

    def _prepare_val(self):
        assert not self.train
        raw_filename, checksum = self.raw_file_meta["val"]
        raw_file = os.path.join(self.root, raw_filename)
        logger.info("checksum valid tar file {} ..".format(raw_file))
        assert calculate_md5(raw_file) == checksum, \
            "checksum mismatch, {} may be damaged".format(raw_file)
        logger.info("extract valid tar file.. this may take 10-20 minutes")
        untar(raw_file, self.target_folder)
        self._organize_val_data()

    def _prepare_train(self):
        assert self.train
        raw_filename, checksum = self.raw_file_meta["train"]
        raw_file = os.path.join(self.root, raw_filename)
        logger.info("checksum train tar file {} ..".format(raw_file))
        assert calculate_md5(raw_file) == checksum, \
            "checksum mismatch, {} may be damaged".format(raw_file)
        logger.info("extract train tar file.. this may take several hours")
        untar(raw_file, self.target_folder)
        # each class comes as a nested tar; extract every one into its own folder
        paths = [
            os.path.join(self.target_folder, child_dir)
            for child_dir in os.listdir(self.target_folder)
        ]
        for path in tqdm(paths):
            untar(path, os.path.splitext(path)[0], remove=True)

    def _prepare_devkit(self):
        raw_filename, checksum = self.raw_file_meta["devkit"]
        raw_file = os.path.join(self.root, raw_filename)
        logger.info("checksum devkit tar file {} ..".format(raw_file))
        assert calculate_md5(raw_file) == checksum, \
            "checksum mismatch, {} may be damaged".format(raw_file)
        logger.info("extract devkit file..")
        untargz(raw_file)
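For reference, a minimal usage sketch of the class above. It assumes the public import path is megengine.data.dataset.ImageNet and that the tar files or already-extracted folders sit under the chosen root (the path below is a placeholder):

import megengine.data.dataset as dataset  # assumed public import path

# assumes ${root} already holds the extracted folders or the raw tar files
train_set = dataset.ImageNet(root="/path/to/imagenet", train=True)
val_set = dataset.ImageNet(root="/path/to/imagenet", train=False)

print(len(train_set))      # number of training images
img, label = train_set[0]  # ImageFolder-style (image, label) sample

Extra keyword arguments such as transforms are forwarded to ImageFolder via **kwargs, so anything ImageFolder accepts can be passed here as well.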

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台
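A quick way to confirm that the bundled CUDA environment can actually see a GPU is a short availability check; this is a minimal sketch assuming the megengine.is_cuda_available helper:

import megengine as mge

# True only when a CUDA-capable GPU and a working driver are visible to MegEngine;
# on CPU-only machines this returns False and code falls back to CPU execution
print("GPU available:", mge.is_cuda_available())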
