
imagenet.py 9.1 kB

# -*- coding: utf-8 -*-
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
# ---------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved.
# ---------------------------------------------------------------------
import os
import shutil

from tqdm import tqdm

from ....distributed.group import is_distributed
from ....logger import get_logger
from ....serialization import load, save
from .folder import ImageFolder
from .utils import _default_dataset_root, calculate_md5, untar, untargz

logger = get_logger(__name__)


class ImageNet(ImageFolder):
    r"""
    Load ImageNet from raw files or folder. Expected folder looks like:

    .. code-block:: bash

        ${root}/
        |       [REQUIRED TAR FILES]
        |-  ILSVRC2012_img_train.tar
        |-  ILSVRC2012_img_val.tar
        |-  ILSVRC2012_devkit_t12.tar.gz
        |       [OPTIONAL IMAGE FOLDERS]
        |-  train/cls/xxx.${img_ext}
        |-  val/cls/xxx.${img_ext}
        |-  ILSVRC2012_devkit_t12/data/meta.mat
        |-  ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt

    If the image folders don't exist, raw tar files are required to get extracted and processed.
    """

    raw_file_meta = {
        "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
        "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
        "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
    }  # ImageNet raw files
    default_train_dir = "train"
    default_val_dir = "val"
    default_devkit_dir = "ILSVRC2012_devkit_t12"

    def __init__(self, root: str = None, train: bool = True, **kwargs):
        r"""
        Initialization:

        * if ``root`` contains ``self.target_folder`` depending on ``train``:

          * initialize ImageFolder with target_folder.

        * else:

          * if all raw files are in ``root``:

            * parse ``self.target_folder`` from raw files.
            * initialize ImageFolder with ``self.target_folder``.

          * else:

            * raise error.

        :param root: root directory of imagenet data, if root is ``None``, use default_dataset_root.
        :param train: if ``True``, load the train split, otherwise load the validation split.
        """
        # process the root path
        if root is None:
            self.root = self._default_root
        else:
            self.root = root

        if not os.path.exists(self.root):
            raise FileNotFoundError("dir %s does not exist" % self.root)

        self.devkit_dir = os.path.join(self.root, self.default_devkit_dir)

        if not os.path.exists(self.devkit_dir):
            logger.warning("devkit directory %s does not exist", self.devkit_dir)
            self._prepare_devkit()

        self.train = train

        if train:
            self.target_folder = os.path.join(self.root, self.default_train_dir)
        else:
            self.target_folder = os.path.join(self.root, self.default_val_dir)

        if not os.path.exists(self.target_folder):
            logger.warning(
                "expected image folder %s does not exist, try to load from raw file",
                self.target_folder,
            )
            if not self.check_raw_file():
                raise FileNotFoundError(
                    "expected image folder %s does not exist, and raw files do not exist in %s"
                    % (self.target_folder, self.root)
                )
            elif is_distributed():
                raise RuntimeError(
                    "extracting raw file shouldn't be done in distributed mode, use single process instead"
                )
            elif train:
                self._prepare_train()
            else:
                self._prepare_val()

        super().__init__(self.target_folder, **kwargs)

    @property
    def _default_root(self):
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def valid_ground_truth(self):
        groud_truth_path = os.path.join(
            self.devkit_dir, "data", "ILSVRC2012_validation_ground_truth.txt"
        )
        if os.path.exists(groud_truth_path):
            with open(groud_truth_path, "r") as f:
                val_labels = f.readlines()
                return [int(val_label) for val_label in val_labels]
        else:
            raise FileNotFoundError(
                "valid ground truth file %s does not exist" % groud_truth_path
            )

    @property
    def meta(self):
        try:
            return load(os.path.join(self.devkit_dir, "meta.pkl"))
        except FileNotFoundError:
            import scipy.io

            meta_path = os.path.join(self.devkit_dir, "data", "meta.mat")
            if not os.path.exists(meta_path):
                raise FileNotFoundError("meta file %s does not exist" % meta_path)
            meta = scipy.io.loadmat(meta_path, squeeze_me=True)["synsets"]
            nums_children = list(zip(*meta))[4]
            meta = [
                meta[idx]
                for idx, num_children in enumerate(nums_children)
                if num_children == 0
            ]
            idcs, wnids, classes = list(zip(*meta))[:3]
            classes = [tuple(clss.split(", ")) for clss in classes]
            idx_to_wnid = dict(zip(idcs, wnids))
            wnid_to_classes = dict(zip(wnids, classes))
            logger.info(
                "saving cached meta file to %s",
                os.path.join(self.devkit_dir, "meta.pkl"),
            )
            save(
                (idx_to_wnid, wnid_to_classes),
                os.path.join(self.devkit_dir, "meta.pkl"),
            )
            return idx_to_wnid, wnid_to_classes

    def check_raw_file(self) -> bool:
        return all(
            [
                os.path.exists(os.path.join(self.root, value[0]))
                for _, value in self.raw_file_meta.items()
            ]
        )

    def _organize_val_data(self):
        id2wnid = self.meta[0]
        val_idcs = self.valid_ground_truth
        val_wnids = [id2wnid[idx] for idx in val_idcs]

        val_images = sorted(
            [
                os.path.join(self.target_folder, image)
                for image in os.listdir(self.target_folder)
            ]
        )

        logger.debug("mkdir for val set wnids")
        for wnid in set(val_wnids):
            os.makedirs(os.path.join(self.root, self.default_val_dir, wnid))

        logger.debug("mv val images into wnids dir")
        for wnid, img_file in tqdm(zip(val_wnids, val_images)):
            shutil.move(
                img_file,
                os.path.join(
                    self.root, self.default_val_dir, wnid, os.path.basename(img_file)
                ),
            )

    def _prepare_val(self):
        assert not self.train
        raw_filename, checksum = self.raw_file_meta["val"]
        raw_file = os.path.join(self.root, raw_filename)
        logger.info("checksum valid tar file %s ...", raw_file)
        assert (
            calculate_md5(raw_file) == checksum
        ), "checksum mismatch, {} may be damaged".format(raw_file)
        logger.info("extract valid tar file... this may take 10-20 minutes")
        untar(os.path.join(self.root, raw_file), self.target_folder)
        self._organize_val_data()

    def _prepare_train(self):
        assert self.train
        raw_filename, checksum = self.raw_file_meta["train"]
        raw_file = os.path.join(self.root, raw_filename)
        logger.info("checksum train tar file %s ...", raw_file)
        assert (
            calculate_md5(raw_file) == checksum
        ), "checksum mismatch, {} may be damaged".format(raw_file)
        logger.info("extract train tar file.. this may take several hours")
        untar(
            os.path.join(self.root, raw_file), self.target_folder,
        )
        paths = [
            os.path.join(self.target_folder, child_dir)
            for child_dir in os.listdir(self.target_folder)
        ]
        for path in tqdm(paths):
            untar(path, os.path.splitext(path)[0], remove=True)

    def _prepare_devkit(self):
        raw_filename, checksum = self.raw_file_meta["devkit"]
        raw_file = os.path.join(self.root, raw_filename)
        logger.info("checksum devkit tar file %s ...", raw_file)
        assert (
            calculate_md5(raw_file) == checksum
        ), "checksum mismatch, {} may be damaged".format(raw_file)
        logger.info("extract devkit file..")
        untargz(os.path.join(self.root, self.raw_file_meta["devkit"][0]))
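
For context, a minimal usage sketch (not part of imagenet.py): it assumes the class is importable as megengine.data.dataset.ImageNet, that the root path /data/imagenet is a placeholder for your own directory containing either the extracted train/val folders or the raw tar files listed in raw_file_meta, and that samples follow the (image, label) convention of the parent ImageFolder.

from megengine.data.dataset import ImageNet

# Build the validation split; when only the raw tars are present, the first
# run verifies ILSVRC2012_img_val.tar, extracts it and re-organizes the
# images into per-wnid folders.
val_dataset = ImageNet(root="/data/imagenet", train=False)

print(len(val_dataset))        # number of validation images
image, label = val_dataset[0]  # one (image, label) sample via ImageFolder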

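The _prepare_* helpers above gate extraction on an md5 check through the internal calculate_md5 utility. As a rough standalone equivalent, the sketch below recomputes those checksums with the standard library only; the dataset root is hypothetical and the import path for ImageNet is assumed.

import hashlib
import os

from megengine.data.dataset import ImageNet  # assumed import path

def md5_of(path, chunk_size=1024 * 1024):
    # Stream the file in chunks so the multi-GB tarballs never sit fully in memory.
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

root = "/data/imagenet"  # hypothetical dataset root
for split, (filename, expected_md5) in ImageNet.raw_file_meta.items():
    path = os.path.join(root, filename)
    if not os.path.exists(path):
        print("%s: %s is missing" % (split, filename))
    elif md5_of(path) != expected_md5:
        print("%s: %s checksum mismatch, file may be damaged" % (split, filename))
    else:
        print("%s: %s OK" % (split, filename))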