|
@@ -24,7 +24,7 @@ from ....core.serialization import load, save |
|
|
from ....distributed.util import is_distributed |
|
|
from ....distributed.util import is_distributed |
|
|
from ....logger import get_logger |
|
|
from ....logger import get_logger |
|
|
from .folder import ImageFolder |
|
|
from .folder import ImageFolder |
|
|
from .utils import _default_dataset_root, untar, untargz |
|
|
|
|
|
|
|
|
from .utils import _default_dataset_root, calculate_md5, untar, untargz |
|
|
|
|
|
|
|
|
logger = get_logger(__name__) |
|
|
logger = get_logger(__name__) |
|
|
|
|
|
|
|
@@ -33,40 +33,28 @@ class ImageNet(ImageFolder): |
|
|
r""" |
|
|
r""" |
|
|
Load ImageNet from raw files or folder, expected folder looks like |
|
|
Load ImageNet from raw files or folder, expected folder looks like |
|
|
|
|
|
|
|
|
raw files situation (optional): |
|
|
|
|
|
root/ILSVRC2012_img_train.tar |
|
|
|
|
|
root/ILSVRC2012_img_val.tar |
|
|
|
|
|
root/ILSVRC2012_devkit_t12.tar.gz |
|
|
|
|
|
|
|
|
|
|
|
image folder situation (required): |
|
|
|
|
|
root/train/cls/xxx.${img_ext} |
|
|
|
|
|
root/val/cls/xxx.${img_ext} |
|
|
|
|
|
root/ILSVRC2012_devkit_t12/data/meta.mat |
|
|
|
|
|
root/ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt |
|
|
|
|
|
|
|
|
|
|
|
If the required folders don't exist, raw files are required to get extracted and processed. |
|
|
|
|
|
|
|
|
${root}/ |
|
|
|
|
|
| [REQUIRED TAR FILES] |
|
|
|
|
|
|- ILSVRC2012_img_train.tar |
|
|
|
|
|
|- ILSVRC2012_img_val.tar |
|
|
|
|
|
|- ILSVRC2012_devkit_t12.tar.gz |
|
|
|
|
|
| [OPTIONAL IMAGE FOLDERS] |
|
|
|
|
|
|- train/cls/xxx.${img_ext} |
|
|
|
|
|
|- val/cls/xxx.${img_ext} |
|
|
|
|
|
|- ILSVRC2012_devkit_t12/data/meta.mat |
|
|
|
|
|
|- ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt |
|
|
|
|
|
|
|
|
|
|
|
If the image folders don't exist, raw tar files are required to get extracted and processed. |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
raw_file_meta = { |
|
|
raw_file_meta = { |
|
|
"train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"), |
|
|
"train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"), |
|
|
"val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"), |
|
|
"val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"), |
|
|
"devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"), |
|
|
"devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"), |
|
|
} |
|
|
|
|
|
""" |
|
|
|
|
|
raw files of ImageNet (train, val, devkit) |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
} # ImageNet raw files |
|
|
default_train_dir = "train" |
|
|
default_train_dir = "train" |
|
|
""" |
|
|
|
|
|
directory of train data |
|
|
|
|
|
""" |
|
|
|
|
|
default_val_dir = "val" |
|
|
default_val_dir = "val" |
|
|
""" |
|
|
|
|
|
directory of val data |
|
|
|
|
|
""" |
|
|
|
|
|
default_devkit_dir = "ILSVRC2012_devkit_t12" |
|
|
default_devkit_dir = "ILSVRC2012_devkit_t12" |
|
|
""" |
|
|
|
|
|
directory of devkit |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, root: str = None, train: bool = True, **kwargs): |
|
|
def __init__(self, root: str = None, train: bool = True, **kwargs): |
|
|
r""" |
|
|
r""" |
|
@@ -97,13 +85,16 @@ class ImageNet(ImageFolder): |
|
|
else: |
|
|
else: |
|
|
self.root = root |
|
|
self.root = root |
|
|
|
|
|
|
|
|
self.devkit_dir = os.path.join(self.root, self.default_devkit_dir) |
|
|
|
|
|
|
|
|
|
|
|
if not os.path.exists(self.root): |
|
|
if not os.path.exists(self.root): |
|
|
raise FileNotFoundError("dir %s does not exist" % self.root) |
|
|
raise FileNotFoundError("dir %s does not exist" % self.root) |
|
|
|
|
|
|
|
|
|
|
|
self.devkit_dir = os.path.join(self.root, self.default_devkit_dir) |
|
|
|
|
|
|
|
|
if not os.path.exists(self.devkit_dir): |
|
|
if not os.path.exists(self.devkit_dir): |
|
|
logger.warning("devkit directory %s does not exists" % self.devkit_dir) |
|
|
logger.warning("devkit directory %s does not exists" % self.devkit_dir) |
|
|
|
|
|
self._prepare_devkit() |
|
|
|
|
|
|
|
|
|
|
|
self.train = train |
|
|
|
|
|
|
|
|
if train: |
|
|
if train: |
|
|
self.target_folder = os.path.join(self.root, self.default_train_dir) |
|
|
self.target_folder = os.path.join(self.root, self.default_train_dir) |
|
@@ -125,7 +116,7 @@ class ImageNet(ImageFolder): |
|
|
"extracting raw file shouldn't be done in distributed mode, use single process instead" |
|
|
"extracting raw file shouldn't be done in distributed mode, use single process instead" |
|
|
) |
|
|
) |
|
|
else: |
|
|
else: |
|
|
self.parse(train) |
|
|
|
|
|
|
|
|
self._prepare_train() if train else self._prepare_val() |
|
|
|
|
|
|
|
|
super().__init__(self.target_folder, **kwargs) |
|
|
super().__init__(self.target_folder, **kwargs) |
|
|
|
|
|
|
|
@@ -180,14 +171,13 @@ class ImageNet(ImageFolder): |
|
|
] |
|
|
] |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
def organize_val_data(self): |
|
|
|
|
|
|
|
|
def _organize_val_data(self): |
|
|
id2wnid = self.meta[0] |
|
|
id2wnid = self.meta[0] |
|
|
val_idcs = self.valid_ground_truth |
|
|
val_idcs = self.valid_ground_truth |
|
|
val_wnids = [id2wnid[idx] for idx in val_idcs] |
|
|
val_wnids = [id2wnid[idx] for idx in val_idcs] |
|
|
|
|
|
|
|
|
raw_val_dir = os.path.join(self.root, "ILSVRC2012_img_val") |
|
|
|
|
|
val_images = sorted( |
|
|
val_images = sorted( |
|
|
[os.path.join(raw_val_dir, image) for image in os.listdir(raw_val_dir)] |
|
|
|
|
|
|
|
|
[os.path.join(self.target_folder, image) for image in os.listdir(self.target_folder)] |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
logger.debug("mkdir for val set wnids") |
|
|
logger.debug("mkdir for val set wnids") |
|
@@ -203,24 +193,41 @@ class ImageNet(ImageFolder): |
|
|
), |
|
|
), |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
def parse(self, train): |
|
|
|
|
|
if train: |
|
|
|
|
|
logger.info("process train raw file.. this may take several hours") |
|
|
|
|
|
untar( |
|
|
|
|
|
os.path.join(self.root, self.raw_file_meta["train"][0]), |
|
|
|
|
|
self.target_folder, |
|
|
|
|
|
) |
|
|
|
|
|
paths = [ |
|
|
|
|
|
os.path.join(self.target_folder, child_dir) |
|
|
|
|
|
for child_dir in os.listdir(self.target_folder) |
|
|
|
|
|
] |
|
|
|
|
|
for path in tqdm(paths): |
|
|
|
|
|
untar(path, os.path.splitext(path)[0], remove=True) |
|
|
|
|
|
else: |
|
|
|
|
|
logger.info("process devkit file..") |
|
|
|
|
|
untargz(os.path.join(self.root, self.raw_file_meta["devkit"][0])) |
|
|
|
|
|
logger.info("process valid raw file.. this may take 10-20 minutes") |
|
|
|
|
|
raw_val_dir = os.path.join(self.root, "ILSVRC2012_img_val") |
|
|
|
|
|
os.makedirs(raw_val_dir, exist_ok=True) |
|
|
|
|
|
untar(os.path.join(self.root, self.raw_file_meta["val"][0]), raw_val_dir) |
|
|
|
|
|
self.organize_val_data() |
|
|
|
|
|
|
|
|
def _prepare_val(self): |
|
|
|
|
|
assert not self.train |
|
|
|
|
|
raw_filename, checksum = self.raw_file_meta["val"] |
|
|
|
|
|
raw_file = os.path.join(self.root, raw_filename) |
|
|
|
|
|
logger.info("checksum valid tar file {} ..".format(raw_file)) |
|
|
|
|
|
assert calculate_md5(raw_file) == checksum, \ |
|
|
|
|
|
"checksum mismatch, {} may be damaged".format(raw_file) |
|
|
|
|
|
logger.info("extract valid tar file.. this may take 10-20 minutes") |
|
|
|
|
|
untar(os.path.join(self.root, raw_file), self.target_folder) |
|
|
|
|
|
self._organize_val_data() |
|
|
|
|
|
|
|
|
|
|
|
def _prepare_train(self): |
|
|
|
|
|
assert self.train |
|
|
|
|
|
raw_filename, checksum = self.raw_file_meta["train"] |
|
|
|
|
|
raw_file = os.path.join(self.root, raw_filename) |
|
|
|
|
|
logger.info("checksum train tar file {} ..".format(raw_file)) |
|
|
|
|
|
assert calculate_md5(raw_file) == checksum, \ |
|
|
|
|
|
"checksum mismatch, {} may be damaged".format(raw_file) |
|
|
|
|
|
logger.info("extract train tar file.. this may take several hours") |
|
|
|
|
|
untar( |
|
|
|
|
|
os.path.join(self.root, raw_file), |
|
|
|
|
|
self.target_folder, |
|
|
|
|
|
) |
|
|
|
|
|
paths = [ |
|
|
|
|
|
os.path.join(self.target_folder, child_dir) |
|
|
|
|
|
for child_dir in os.listdir(self.target_folder) |
|
|
|
|
|
] |
|
|
|
|
|
for path in tqdm(paths): |
|
|
|
|
|
untar(path, os.path.splitext(path)[0], remove=True) |
|
|
|
|
|
|
|
|
|
|
|
def _prepare_devkit(self): |
|
|
|
|
|
raw_filename, checksum = self.raw_file_meta["val"] |
|
|
|
|
|
raw_file = os.path.join(self.root, raw_filename) |
|
|
|
|
|
logger.info("checksum devkit tar file {} ..".format(raw_file)) |
|
|
|
|
|
assert calculate_md5(raw_file) == checksum, \ |
|
|
|
|
|
"checksum mismatch, {} may be damaged".format(raw_file) |
|
|
|
|
|
logger.info("extract devkit file..") |
|
|
|
|
|
untargz(os.path.join(self.root, self.raw_file_meta["devkit"][0])) |