# -*- coding: utf-8 -*-
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
# ---------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved.
# ---------------------------------------------------------------------
import os
import shutil

from tqdm import tqdm

from ....core.serialization import load, save
from ....distributed.util import is_distributed
from ....logger import get_logger
from .folder import ImageFolder
from .utils import _default_dataset_root, untar, untargz

logger = get_logger(__name__)


class ImageNet(ImageFolder):
    r"""
    Load ImageNet from raw files or folder, expected folder looks like

    raw files situation (optional):
        root/ILSVRC2012_img_train.tar
        root/ILSVRC2012_img_val.tar
        root/ILSVRC2012_devkit_t12.tar.gz
    image folder situation (required):
        root/train/cls/xxx.${img_ext}
        root/val/cls/xxx.${img_ext}
        root/ILSVRC2012_devkit_t12/data/meta.mat
        root/ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt

    If the required folders don't exist, raw files are required to get extracted and processed.
    """

    raw_file_meta = {
        "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
        "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
        "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
    }
    """
    raw files of ImageNet (train, val, devkit)
    """

    default_train_dir = "train"
    """
    directory of train data
    """

    default_val_dir = "val"
    """
    directory of val data
    """

    default_devkit_dir = "ILSVRC2012_devkit_t12"
    """
    directory of devkit
    """

    def __init__(self, root: str = None, train: bool = True, **kwargs):
        r"""
        initialization:

        * if ``root`` contains ``self.target_folder`` depending on ``train``:

            * initialize ImageFolder with target_folder

        * else:

            * if all raw files are in ``root``:

                * parse ``self.target_folder`` from raw files
                * initialize ImageFolder with ``self.target_folder``

            * else:

                * raise error

        :param root: root directory of imagenet data, if root is ``None``, used default_dataset_root
        :param train: if ``True``, load the train split, otherwise load the validation split
        :param kwargs: forwarded unchanged to ``ImageFolder.__init__``
        :raises FileNotFoundError: if ``root`` does not exist, or if neither the
            image folder nor the complete set of raw files is available
        :raises RuntimeError: if raw files would have to be extracted while
            running in distributed mode
        """
        # process the root path
        self.root = self._default_root if root is None else root
        self.devkit_dir = os.path.join(self.root, self.default_devkit_dir)

        if not os.path.exists(self.root):
            raise FileNotFoundError("dir %s does not exist" % self.root)

        # A missing devkit is only a warning here: it is needed for the
        # ``meta`` / ``valid_ground_truth`` properties, not for loading an
        # already-organized image folder.
        if not os.path.exists(self.devkit_dir):
            logger.warning("devkit directory %s does not exists", self.devkit_dir)

        split_dir = self.default_train_dir if train else self.default_val_dir
        self.target_folder = os.path.join(self.root, split_dir)

        if not os.path.exists(self.target_folder):
            logger.warning(
                "expected image folder %s does not exist, try to load from raw file",
                self.target_folder,
            )
            if not self.check_raw_file():
                raise FileNotFoundError(
                    "expected image folder %s does not exist, and raw files do not exist in %s"
                    % (self.target_folder, self.root)
                )
            elif is_distributed():
                raise RuntimeError(
                    "extracting raw file shouldn't be done in distributed mode, use single process instead"
                )
            else:
                self.parse(train)

        super().__init__(self.target_folder, **kwargs)

    @property
    def _default_root(self):
        r"""
        Default dataset directory: the default dataset root joined with the
        class name.
        """
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def valid_ground_truth(self):
        r"""
        Validation labels read from the devkit ground-truth file, one integer
        per non-blank line, in file order.

        :raises FileNotFoundError: if the ground-truth file does not exist
        """
        ground_truth_path = os.path.join(
            self.devkit_dir, "data", "ILSVRC2012_validation_ground_truth.txt"
        )
        if not os.path.exists(ground_truth_path):
            raise FileNotFoundError(
                "valid ground truth file %s does not exist" % ground_truth_path
            )
        with open(ground_truth_path, "r") as f:
            # Skip blank lines (e.g. a trailing empty line) instead of
            # failing on ``int("")``.
            return [int(line) for line in f if line.strip()]

    @property
    def meta(self):
        r"""
        Return the tuple ``(idx_to_wnid, wnid_to_classes)``.

        The result is read from ``meta.pkl`` in the devkit directory when that
        cache exists. Otherwise it is built from ``data/meta.mat`` (requires
        ``scipy``) and written to ``meta.pkl`` for later calls.

        :raises FileNotFoundError: if neither the cache nor ``meta.mat`` exists
        """
        cache_path = os.path.join(self.devkit_dir, "meta.pkl")
        try:
            return load(cache_path)
        except FileNotFoundError:
            # scipy is imported lazily so it is only needed when the cache
            # has to be (re)built.
            import scipy.io as sio

            meta_path = os.path.join(self.devkit_dir, "data", "meta.mat")
            if not os.path.exists(meta_path):
                raise FileNotFoundError("meta file %s does not exist" % meta_path)

            meta = sio.loadmat(meta_path, squeeze_me=True)["synsets"]
            # The fifth field of every synset record is used as its number of
            # children; only records with no children are kept.
            nums_children = list(zip(*meta))[4]
            meta = [
                meta[idx]
                for idx, num_children in enumerate(nums_children)
                if num_children == 0
            ]
            idcs, wnids, classes = list(zip(*meta))[:3]
            classes = [tuple(clss.split(", ")) for clss in classes]
            idx_to_wnid = dict(zip(idcs, wnids))
            wnid_to_classes = dict(zip(wnids, classes))
            logger.info("saving cached meta file to %s", cache_path)
            save((idx_to_wnid, wnid_to_classes), cache_path)
            return idx_to_wnid, wnid_to_classes

    def check_raw_file(self) -> bool:
        r"""
        Return ``True`` only if every file listed in ``raw_file_meta`` exists
        directly under ``self.root``.
        """
        return all(
            os.path.exists(os.path.join(self.root, file_meta[0]))
            for file_meta in self.raw_file_meta.values()
        )

    def organize_val_data(self):
        r"""
        Move the flat validation images from ``root/ILSVRC2012_img_val`` into
        per-class sub-directories ``root/val/<wnid>/``.

        Images are sorted by path and paired positionally with the labels from
        ``valid_ground_truth``.

        :raises RuntimeError: if the number of images differs from the number
            of labels; nothing is created or moved in that case
        """
        id2wnid = self.meta[0]
        val_idcs = self.valid_ground_truth
        val_wnids = [id2wnid[idx] for idx in val_idcs]

        # NOTE(review): this relies on the val archive having been extracted
        # into ``root/ILSVRC2012_img_val`` -- confirm against ``untar``.
        raw_val_dir = os.path.join(self.root, "ILSVRC2012_img_val")
        val_images = sorted(
            os.path.join(raw_val_dir, image) for image in os.listdir(raw_val_dir)
        )

        # The pairing below is purely positional, so a count mismatch (partial
        # extraction, or a re-run after some images were already moved) would
        # silently file images under the wrong class. Fail before touching
        # anything on disk.
        if len(val_images) != len(val_wnids):
            raise RuntimeError(
                "number of val images in %s (%d) does not match number of "
                "ground truth labels (%d)"
                % (raw_val_dir, len(val_images), len(val_wnids))
            )

        val_root = os.path.join(self.root, self.default_val_dir)

        logger.debug("mkdir for val set wnids")
        for wnid in set(val_wnids):
            # exist_ok: tolerate directories left over from an interrupted run
            os.makedirs(os.path.join(val_root, wnid), exist_ok=True)

        logger.debug("mv val images into wnids dir")
        # zip() has no length, so give tqdm the total explicitly.
        for wnid, img_file in tqdm(zip(val_wnids, val_images), total=len(val_images)):
            shutil.move(
                img_file, os.path.join(val_root, wnid, os.path.basename(img_file)),
            )

    def parse(self, train):
        r"""
        Extract the raw archives for the requested split.

        :param train: if ``True``, extract the train archive into
            ``self.target_folder`` and then extract every entry found there
            into a directory named after it (extension stripped), removing the
            inner archive; otherwise extract the devkit and val archives and
            organize the val images by class.
        """
        if train:
            logger.info("process train raw file.. this may take several hours")
            untar(
                os.path.join(self.root, self.raw_file_meta["train"][0]),
                self.target_folder,
            )
            paths = [
                os.path.join(self.target_folder, child_dir)
                for child_dir in os.listdir(self.target_folder)
            ]
            for path in tqdm(paths):
                untar(path, os.path.splitext(path)[0], remove=True)
        else:
            logger.info("process devkit file..")
            untargz(os.path.join(self.root, self.raw_file_meta["devkit"][0]))
            logger.info("process valid raw file.. this may take 10-20 minutes")
            untar(os.path.join(self.root, self.raw_file_meta["val"][0]))
            self.organize_val_data()