|
- # -*- coding: utf-8 -*-
- # BSD 3-Clause License
-
- # Copyright (c) Soumith Chintala 2016,
- # All rights reserved.
- # ---------------------------------------------------------------------
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- #
- # This file has been modified by Megvii ("Megvii Modifications").
- # All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved.
- # ---------------------------------------------------------------------
- import os
- from typing import Dict, List, Tuple
-
- import cv2
- import numpy as np
-
- from .meta_vision import VisionDataset
- from .utils import is_img
-
-
class ImageFolder(VisionDataset):
    r"""ImageFolder is a class for loading image data and labels from an organized folder.

    The folder is expected to be organized as follows: root/cls/xxx.img_ext

    Labels are indices of sorted classes in the root directory.

    Images are read with ``cv2.imread(path, cv2.IMREAD_COLOR)``.

    Args:
        root: root directory of an image folder. A leading ``~`` is expanded
            to the user's home directory whenever the folder is scanned.
        check_valid_func: a function called with a file path that returns
            whether the file is an expected image file, if ``None``,
            the default ``is_img`` function will be called.
        class_name: if ``True``, return class name instead of class index.
    """

    def __init__(self, root: str, check_valid_func=None, class_name: bool = False):
        super().__init__(root, order=("image", "image_category"))

        # Keep the caller's path untouched; it is expanded at scan time.
        self.root = root
        self.check_valid = is_img if check_valid_func is None else check_valid_func
        self.class_name = class_name

        # The class mapping must exist before samples are collected,
        # because collect_samples() iterates over it.
        self.class_dict = self.collect_class()
        self.samples = self.collect_samples()

    def collect_samples(self) -> List:
        """Walk every class directory and gather ``(path, label)`` pairs.

        Returns:
            A list of ``(path, label)`` tuples ordered by class name, then by
            directory path, then by file name. ``label`` is the class name
            when ``self.class_name`` is true, otherwise the class index
            stored in ``self.class_dict``.
        """
        samples = []
        directory = os.path.expanduser(self.root)
        for key in sorted(self.class_dict):
            class_dir = os.path.join(directory, key)
            # The directory may have disappeared since collect_class() ran.
            if not os.path.isdir(class_dir):
                continue
            # The label is the same for every file below this class directory.
            label = key if self.class_name else self.class_dict[key]
            # Sorting the walk and the file names keeps the sample order
            # deterministic regardless of filesystem enumeration order.
            for dirpath, _, filenames in sorted(os.walk(class_dir, followlinks=True)):
                for name in sorted(filenames):
                    path = os.path.join(dirpath, name)
                    if self.check_valid(path):
                        samples.append((path, label))
        return samples

    def collect_class(self) -> Dict:
        """Map each sub-directory name of the root folder to its class index.

        Returns:
            A dict from class name to ``np.int32`` index, where indices follow
            the sorted order of the class names.
        """
        # Expand ``~`` here as well so that this scan and collect_samples()
        # always look at the same directory.
        directory = os.path.expanduser(self.root)
        classes = sorted(
            entry.name for entry in os.scandir(directory) if entry.is_dir()
        )
        return {name: np.int32(index) for index, name in enumerate(classes)}

    def __getitem__(self, index: int) -> Tuple:
        """Load the sample at ``index``.

        Returns:
            An ``(image, label)`` tuple. ``image`` is whatever
            ``cv2.imread`` returns for the stored path; no check is made
            here that the read succeeded, so callers should be prepared for
            an unreadable file (OpenCV reports that case through the return
            value rather than by raising).
        """
        path, label = self.samples[index]
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        return img, label

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.samples)
|