You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

imagenet.py 9.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. # -*- coding: utf-8 -*-
  2. # BSD 3-Clause License
  3. #
  4. # Copyright (c) Soumith Chintala 2016,
  5. # All rights reserved.
  6. # ---------------------------------------------------------------------
  7. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  8. #
  9. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. #
  15. # This file has been modified by Megvii ("Megvii Modifications").
  16. # All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved.
  17. # ---------------------------------------------------------------------
  18. import os
  19. import shutil
  20. from tqdm import tqdm
  21. from ....distributed.group import is_distributed
  22. from ....logger import get_logger
  23. from ....serialization import load, save
  24. from .folder import ImageFolder
  25. from .utils import _default_dataset_root, calculate_md5, untar, untargz
  26. logger = get_logger(__name__)
  27. class ImageNet(ImageFolder):
  28. r"""Load ImageNet from raw files or folder. Expected folder looks like:
  29. .. code-block:: shell
  30. ${root}/
  31. | [REQUIRED TAR FILES]
  32. |- ILSVRC2012_img_train.tar
  33. |- ILSVRC2012_img_val.tar
  34. |- ILSVRC2012_devkit_t12.tar.gz
  35. | [OPTIONAL IMAGE FOLDERS]
  36. |- train/cls/xxx.${img_ext}
  37. |- val/cls/xxx.${img_ext}
  38. |- ILSVRC2012_devkit_t12/data/meta.mat
  39. |- ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt
  40. If the image folders don't exist, raw tar files are required to get extracted and processed.
  41. * if ``root`` contains ``self.target_folder`` depending on ``train``:
  42. * initialize ImageFolder with target_folder.
  43. * else:
  44. * if all raw files are in ``root``:
  45. * parse ``self.target_folder`` from raw files.
  46. * initialize ImageFolder with ``self.target_folder``.
  47. * else:
  48. * raise error.
  49. Args:
  50. root: root directory of imagenet data, if root is ``None``, use default_dataset_root.
  51. train: if ``True``, load the train split, otherwise load the validation split.
  52. """
  53. raw_file_meta = {
  54. "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
  55. "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
  56. "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
  57. } # ImageNet raw files
  58. default_train_dir = "train"
  59. default_val_dir = "val"
  60. default_devkit_dir = "ILSVRC2012_devkit_t12"
  61. def __init__(self, root: str = None, train: bool = True, **kwargs):
  62. # process the root path
  63. if root is None:
  64. self.root = self._default_root
  65. else:
  66. self.root = root
  67. if not os.path.exists(self.root):
  68. raise FileNotFoundError("dir %s does not exist" % self.root)
  69. self.devkit_dir = os.path.join(self.root, self.default_devkit_dir)
  70. if not os.path.exists(self.devkit_dir):
  71. logger.warning("devkit directory %s does not exists", self.devkit_dir)
  72. self._prepare_devkit()
  73. self.train = train
  74. if train:
  75. self.target_folder = os.path.join(self.root, self.default_train_dir)
  76. else:
  77. self.target_folder = os.path.join(self.root, self.default_val_dir)
  78. if not os.path.exists(self.target_folder):
  79. logger.warning(
  80. "expected image folder %s does not exist, try to load from raw file",
  81. self.target_folder,
  82. )
  83. if not self.check_raw_file():
  84. raise FileNotFoundError(
  85. "expected image folder %s does not exist, and raw files do not exist in %s"
  86. % (self.target_folder, self.root)
  87. )
  88. elif is_distributed():
  89. raise RuntimeError(
  90. "extracting raw file shouldn't be done in distributed mode, use single process instead"
  91. )
  92. elif train:
  93. self._prepare_train()
  94. else:
  95. self._prepare_val()
  96. super().__init__(self.target_folder, **kwargs)
  97. @property
  98. def _default_root(self):
  99. return os.path.join(_default_dataset_root(), self.__class__.__name__)
  100. @property
  101. def valid_ground_truth(self):
  102. groud_truth_path = os.path.join(
  103. self.devkit_dir, "data", "ILSVRC2012_validation_ground_truth.txt"
  104. )
  105. if os.path.exists(groud_truth_path):
  106. with open(groud_truth_path, "r") as f:
  107. val_labels = f.readlines()
  108. return [int(val_label) for val_label in val_labels]
  109. else:
  110. raise FileNotFoundError(
  111. "valid ground truth file %s does not exist" % groud_truth_path
  112. )
  113. @property
  114. def meta(self):
  115. try:
  116. return load(os.path.join(self.devkit_dir, "meta.pkl"))
  117. except FileNotFoundError:
  118. import scipy.io
  119. meta_path = os.path.join(self.devkit_dir, "data", "meta.mat")
  120. if not os.path.exists(meta_path):
  121. raise FileNotFoundError("meta file %s does not exist" % meta_path)
  122. meta = scipy.io.loadmat(meta_path, squeeze_me=True)["synsets"]
  123. nums_children = list(zip(*meta))[4]
  124. meta = [
  125. meta[idx]
  126. for idx, num_children in enumerate(nums_children)
  127. if num_children == 0
  128. ]
  129. idcs, wnids, classes = list(zip(*meta))[:3]
  130. classes = [tuple(clss.split(", ")) for clss in classes]
  131. idx_to_wnid = dict(zip(idcs, wnids))
  132. wnid_to_classes = dict(zip(wnids, classes))
  133. logger.info(
  134. "saving cached meta file to %s",
  135. os.path.join(self.devkit_dir, "meta.pkl"),
  136. )
  137. save(
  138. (idx_to_wnid, wnid_to_classes),
  139. os.path.join(self.devkit_dir, "meta.pkl"),
  140. )
  141. return idx_to_wnid, wnid_to_classes
  142. def check_raw_file(self) -> bool:
  143. return all(
  144. [
  145. os.path.exists(os.path.join(self.root, value[0]))
  146. for _, value in self.raw_file_meta.items()
  147. ]
  148. )
  149. def _organize_val_data(self):
  150. id2wnid = self.meta[0]
  151. val_idcs = self.valid_ground_truth
  152. val_wnids = [id2wnid[idx] for idx in val_idcs]
  153. val_images = sorted(
  154. [
  155. os.path.join(self.target_folder, image)
  156. for image in os.listdir(self.target_folder)
  157. ]
  158. )
  159. logger.debug("mkdir for val set wnids")
  160. for wnid in set(val_wnids):
  161. os.makedirs(os.path.join(self.root, self.default_val_dir, wnid))
  162. logger.debug("mv val images into wnids dir")
  163. for wnid, img_file in tqdm(zip(val_wnids, val_images)):
  164. shutil.move(
  165. img_file,
  166. os.path.join(
  167. self.root, self.default_val_dir, wnid, os.path.basename(img_file)
  168. ),
  169. )
  170. def _prepare_val(self):
  171. assert not self.train
  172. raw_filename, checksum = self.raw_file_meta["val"]
  173. raw_file = os.path.join(self.root, raw_filename)
  174. logger.info("checksum valid tar file %s ...", raw_file)
  175. assert (
  176. calculate_md5(raw_file) == checksum
  177. ), "checksum mismatch, {} may be damaged".format(raw_file)
  178. logger.info("extract valid tar file... this may take 10-20 minutes")
  179. untar(os.path.join(self.root, raw_file), self.target_folder)
  180. self._organize_val_data()
  181. def _prepare_train(self):
  182. assert self.train
  183. raw_filename, checksum = self.raw_file_meta["train"]
  184. raw_file = os.path.join(self.root, raw_filename)
  185. logger.info("checksum train tar file %s ...", raw_file)
  186. assert (
  187. calculate_md5(raw_file) == checksum
  188. ), "checksum mismatch, {} may be damaged".format(raw_file)
  189. logger.info("extract train tar file.. this may take several hours")
  190. untar(
  191. os.path.join(self.root, raw_file), self.target_folder,
  192. )
  193. paths = [
  194. os.path.join(self.target_folder, child_dir)
  195. for child_dir in os.listdir(self.target_folder)
  196. ]
  197. for path in tqdm(paths):
  198. untar(path, os.path.splitext(path)[0], remove=True)
  199. def _prepare_devkit(self):
  200. raw_filename, checksum = self.raw_file_meta["devkit"]
  201. raw_file = os.path.join(self.root, raw_filename)
  202. logger.info("checksum devkit tar file %s ...", raw_file)
  203. assert (
  204. calculate_md5(raw_file) == checksum
  205. ), "checksum mismatch, {} may be damaged".format(raw_file)
  206. logger.info("extract devkit file..")
  207. untargz(os.path.join(self.root, self.raw_file_meta["devkit"][0]))

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台