You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

imagenet.py 9.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. # -*- coding: utf-8 -*-
  2. # BSD 3-Clause License
  3. #
  4. # Copyright (c) Soumith Chintala 2016,
  5. # All rights reserved.
  6. # ---------------------------------------------------------------------
  7. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  8. #
  9. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. #
  15. # This file has been modified by Megvii ("Megvii Modifications").
  16. # All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved.
  17. # ---------------------------------------------------------------------
  18. import os
  19. import shutil
  20. from tqdm import tqdm
  21. from ....core.serialization import load, save
  22. from ....distributed.util import is_distributed
  23. from ....logger import get_logger
  24. from .folder import ImageFolder
  25. from .utils import _default_dataset_root, calculate_md5, untar, untargz
  26. logger = get_logger(__name__)
  27. class ImageNet(ImageFolder):
  28. r"""
  29. Load ImageNet from raw files or folder, expected folder looks like
  30. .. code-block:: bash
  31. ${root}/
  32. | [REQUIRED TAR FILES]
  33. |- ILSVRC2012_img_train.tar
  34. |- ILSVRC2012_img_val.tar
  35. |- ILSVRC2012_devkit_t12.tar.gz
  36. | [OPTIONAL IMAGE FOLDERS]
  37. |- train/cls/xxx.${img_ext}
  38. |- val/cls/xxx.${img_ext}
  39. |- ILSVRC2012_devkit_t12/data/meta.mat
  40. |- ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt
  41. If the image folders don't exist, raw tar files are required to get extracted and processed.
  42. """
  43. raw_file_meta = {
  44. "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
  45. "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
  46. "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
  47. } # ImageNet raw files
  48. default_train_dir = "train"
  49. default_val_dir = "val"
  50. default_devkit_dir = "ILSVRC2012_devkit_t12"
  51. def __init__(self, root: str = None, train: bool = True, **kwargs):
  52. r"""
  53. initialization:
  54. * if ``root`` contains ``self.target_folder`` depent on ``train``:
  55. * initialize ImageFolder with target_folder
  56. * else:
  57. * if all raw files are in ``root``:
  58. * parse ``self.target_folder`` from raw files
  59. * initialize ImageFolder with ``self.target_folder``
  60. * else:
  61. * raise error
  62. :param root: root directory of imagenet data, if root is ``None``, used default_dataset_root
  63. :param train: if ``True``, load the train split, otherwise load the validation split
  64. """
  65. # process the root path
  66. if root is None:
  67. self.root = self._default_root
  68. else:
  69. self.root = root
  70. if not os.path.exists(self.root):
  71. raise FileNotFoundError("dir %s does not exist" % self.root)
  72. self.devkit_dir = os.path.join(self.root, self.default_devkit_dir)
  73. if not os.path.exists(self.devkit_dir):
  74. logger.warning("devkit directory %s does not exists" % self.devkit_dir)
  75. self._prepare_devkit()
  76. self.train = train
  77. if train:
  78. self.target_folder = os.path.join(self.root, self.default_train_dir)
  79. else:
  80. self.target_folder = os.path.join(self.root, self.default_val_dir)
  81. if not os.path.exists(self.target_folder):
  82. logger.warning(
  83. "expected image folder %s does not exist, try to load from raw file"
  84. % self.target_folder
  85. )
  86. if not self.check_raw_file():
  87. raise FileNotFoundError(
  88. "expected image folder %s does not exist, and raw files do not exist in %s"
  89. % (self.target_folder, self.root)
  90. )
  91. elif is_distributed():
  92. raise RuntimeError(
  93. "extracting raw file shouldn't be done in distributed mode, use single process instead"
  94. )
  95. else:
  96. self._prepare_train() if train else self._prepare_val()
  97. super().__init__(self.target_folder, **kwargs)
  98. @property
  99. def _default_root(self):
  100. return os.path.join(_default_dataset_root(), self.__class__.__name__)
  101. @property
  102. def valid_ground_truth(self):
  103. groud_truth_path = os.path.join(
  104. self.devkit_dir, "data", "ILSVRC2012_validation_ground_truth.txt"
  105. )
  106. if os.path.exists(groud_truth_path):
  107. with open(groud_truth_path, "r") as f:
  108. val_labels = f.readlines()
  109. return [int(val_label) for val_label in val_labels]
  110. else:
  111. raise FileNotFoundError(
  112. "valid ground truth file %s does not exist" % groud_truth_path
  113. )
  114. @property
  115. def meta(self):
  116. try:
  117. return load(os.path.join(self.devkit_dir, "meta.pkl"))
  118. except FileNotFoundError:
  119. import scipy.io as sio
  120. meta_path = os.path.join(self.devkit_dir, "data", "meta.mat")
  121. if not os.path.exists(meta_path):
  122. raise FileNotFoundError("meta file %s does not exist" % meta_path)
  123. meta = sio.loadmat(meta_path, squeeze_me=True)["synsets"]
  124. nums_children = list(zip(*meta))[4]
  125. meta = [
  126. meta[idx]
  127. for idx, num_children in enumerate(nums_children)
  128. if num_children == 0
  129. ]
  130. idcs, wnids, classes = list(zip(*meta))[:3]
  131. classes = [tuple(clss.split(", ")) for clss in classes]
  132. idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
  133. wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
  134. logger.info(
  135. "saving cached meta file to %s",
  136. os.path.join(self.devkit_dir, "meta.pkl"),
  137. )
  138. save(
  139. (idx_to_wnid, wnid_to_classes),
  140. os.path.join(self.devkit_dir, "meta.pkl"),
  141. )
  142. return idx_to_wnid, wnid_to_classes
  143. def check_raw_file(self) -> bool:
  144. return all(
  145. [
  146. os.path.exists(os.path.join(self.root, value[0]))
  147. for _, value in self.raw_file_meta.items()
  148. ]
  149. )
  150. def _organize_val_data(self):
  151. id2wnid = self.meta[0]
  152. val_idcs = self.valid_ground_truth
  153. val_wnids = [id2wnid[idx] for idx in val_idcs]
  154. val_images = sorted(
  155. [
  156. os.path.join(self.target_folder, image)
  157. for image in os.listdir(self.target_folder)
  158. ]
  159. )
  160. logger.debug("mkdir for val set wnids")
  161. for wnid in set(val_wnids):
  162. os.makedirs(os.path.join(self.root, self.default_val_dir, wnid))
  163. logger.debug("mv val images into wnids dir")
  164. for wnid, img_file in tqdm(zip(val_wnids, val_images)):
  165. shutil.move(
  166. img_file,
  167. os.path.join(
  168. self.root, self.default_val_dir, wnid, os.path.basename(img_file)
  169. ),
  170. )
  171. def _prepare_val(self):
  172. assert not self.train
  173. raw_filename, checksum = self.raw_file_meta["val"]
  174. raw_file = os.path.join(self.root, raw_filename)
  175. logger.info("checksum valid tar file {} ..".format(raw_file))
  176. assert (
  177. calculate_md5(raw_file) == checksum
  178. ), "checksum mismatch, {} may be damaged".format(raw_file)
  179. logger.info("extract valid tar file.. this may take 10-20 minutes")
  180. untar(os.path.join(self.root, raw_file), self.target_folder)
  181. self._organize_val_data()
  182. def _prepare_train(self):
  183. assert self.train
  184. raw_filename, checksum = self.raw_file_meta["train"]
  185. raw_file = os.path.join(self.root, raw_filename)
  186. logger.info("checksum train tar file {} ..".format(raw_file))
  187. assert (
  188. calculate_md5(raw_file) == checksum
  189. ), "checksum mismatch, {} may be damaged".format(raw_file)
  190. logger.info("extract train tar file.. this may take several hours")
  191. untar(
  192. os.path.join(self.root, raw_file), self.target_folder,
  193. )
  194. paths = [
  195. os.path.join(self.target_folder, child_dir)
  196. for child_dir in os.listdir(self.target_folder)
  197. ]
  198. for path in tqdm(paths):
  199. untar(path, os.path.splitext(path)[0], remove=True)
  200. def _prepare_devkit(self):
  201. raw_filename, checksum = self.raw_file_meta["devkit"]
  202. raw_file = os.path.join(self.root, raw_filename)
  203. logger.info("checksum devkit tar file {} ..".format(raw_file))
  204. assert (
  205. calculate_md5(raw_file) == checksum
  206. ), "checksum mismatch, {} may be damaged".format(raw_file)
  207. logger.info("extract devkit file..")
  208. untargz(os.path.join(self.root, self.raw_file_meta["devkit"][0]))

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台

Contributors (1)