You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

folder.py 2.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. # -*- coding: utf-8 -*-
  2. # BSD 3-Clause License
  3. # Copyright (c) Soumith Chintala 2016,
  4. # All rights reserved.
  5. # ---------------------------------------------------------------------
  6. #
  7. # This file has been modified by Megvii ("Megvii Modifications").
  8. # All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved.
  9. # ---------------------------------------------------------------------
  10. import os
  11. from typing import Dict, List, Tuple
  12. import cv2
  13. import numpy as np
  14. from .meta_vision import VisionDataset
  15. from .utils import is_img
  16. class ImageFolder(VisionDataset):
  17. r"""ImageFolder is a class for loading image data and labels from a organized folder.
  18. The folder is expected to be organized as followed: root/cls/xxx.img_ext
  19. Labels are indices of sorted classes in the root directory.
  20. Args:
  21. root: root directory of an image folder.
  22. loader: a function used to load image from path,
  23. if ``None``, default function that loads
  24. images with PIL will be called.
  25. check_valid_func: a function used to check if files in folder are
  26. expected image files, if ``None``, default function
  27. that checks file extensions will be called.
  28. class_name: if ``True``, return class name instead of class index.
  29. """
  30. def __init__(self, root: str, check_valid_func=None, class_name: bool = False):
  31. super().__init__(root, order=("image", "image_category"))
  32. self.root = root
  33. if check_valid_func is not None:
  34. self.check_valid = check_valid_func
  35. else:
  36. self.check_valid = is_img
  37. self.class_name = class_name
  38. self.class_dict = self.collect_class()
  39. self.samples = self.collect_samples()
  40. def collect_samples(self) -> List:
  41. samples = []
  42. directory = os.path.expanduser(self.root)
  43. for key in sorted(self.class_dict.keys()):
  44. d = os.path.join(directory, key)
  45. if not os.path.isdir(d):
  46. continue
  47. for r, _, filename in sorted(os.walk(d, followlinks=True)):
  48. for name in sorted(filename):
  49. path = os.path.join(r, name)
  50. if self.check_valid(path):
  51. if self.class_name:
  52. samples.append((path, key))
  53. else:
  54. samples.append((path, self.class_dict[key]))
  55. return samples
  56. def collect_class(self) -> Dict:
  57. classes = [d.name for d in os.scandir(self.root) if d.is_dir()]
  58. classes.sort()
  59. return {classes[i]: np.int32(i) for i in range(len(classes))}
  60. def __getitem__(self, index: int) -> Tuple:
  61. path, label = self.samples[index]
  62. img = cv2.imread(path, cv2.IMREAD_COLOR)
  63. return img, label
  64. def __len__(self):
  65. return len(self.samples)