You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

folder.py 3.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. # -*- coding: utf-8 -*-
  2. # BSD 3-Clause License
  3. # Copyright (c) Soumith Chintala 2016,
  4. # All rights reserved.
  5. # ---------------------------------------------------------------------
  6. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  7. #
  8. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  9. #
  10. # Unless required by applicable law or agreed to in writing,
  11. # software distributed under the License is distributed on an
  12. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. #
  14. # This file has been modified by Megvii ("Megvii Modifications").
  15. # All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved.
  16. # ---------------------------------------------------------------------
  17. import os
  18. from typing import Dict, List, Tuple
  19. import cv2
  20. import numpy as np
  21. from .meta_vision import VisionDataset
  22. from .utils import is_img
  23. class ImageFolder(VisionDataset):
  24. def __init__(self, root: str, check_valid_func=None, class_name: bool = False):
  25. r"""
  26. ImageFolder is a class for loading image data and labels from a organized folder.
  27. the folder is expected to be organized as followed
  28. root/cls/xxx.img_ext
  29. labels are indices of sorted classes in the root directory
  30. :param root: root directory of an image folder
  31. :param loader: a function used to load image from path,
  32. if ``None``, default function that loads
  33. images with PILwill be called
  34. :param check_valid_func: a function used to check if files in folder are
  35. expected image files, if ``None``, default function
  36. that checks file extensions will be called
  37. :param class_name: if ``True``, return class name instead of class index
  38. """
  39. super().__init__(root, order=("image", "image_category"))
  40. self.root = root
  41. if check_valid_func is not None:
  42. self.check_valid = check_valid_func
  43. else:
  44. self.check_valid = is_img
  45. self.class_name = class_name
  46. self.class_dict = self.collect_class()
  47. self.samples = self.collect_samples()
  48. def collect_samples(self) -> List:
  49. samples = []
  50. directory = os.path.expanduser(self.root)
  51. for key in sorted(self.class_dict.keys()):
  52. d = os.path.join(directory, key)
  53. if not os.path.isdir(d):
  54. continue
  55. for r, _, filename in sorted(os.walk(d, followlinks=True)):
  56. for name in sorted(filename):
  57. path = os.path.join(r, name)
  58. if self.check_valid(path):
  59. if self.class_name:
  60. samples.append((path, key))
  61. else:
  62. samples.append((path, self.class_dict[key]))
  63. return samples
  64. def collect_class(self) -> Dict:
  65. classes = [d.name for d in os.scandir(self.root) if d.is_dir()]
  66. classes.sort()
  67. return {classes[i]: np.int32(i) for i in range(len(classes))}
  68. def __getitem__(self, index: int) -> Tuple:
  69. path, label = self.samples[index]
  70. img = cv2.imread(path, cv2.IMREAD_COLOR)
  71. return img, label
  72. def __len__(self):
  73. return len(self.samples)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台