You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

folder.py 3.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # -*- coding: utf-8 -*-
  2. # BSD 3-Clause License
  3. # Copyright (c) Soumith Chintala 2016,
  4. # All rights reserved.
  5. # ---------------------------------------------------------------------
  6. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  7. #
  8. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  9. #
  10. # Unless required by applicable law or agreed to in writing,
  11. # software distributed under the License is distributed on an
  12. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. #
  14. # This file has been modified by Megvii ("Megvii Modifications").
  15. # All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved.
  16. # ---------------------------------------------------------------------
  17. import os
  18. from typing import Dict, List, Tuple
  19. import cv2
  20. import numpy as np
  21. from .meta_vision import VisionDataset
  22. from .utils import is_img
  class ImageFolder(VisionDataset):
      r"""
      Dataset that loads images and labels from an organized folder.

      The folder is expected to be organized as ``root/cls/xxx.img_ext``.
      Every immediate sub-directory of ``root`` is treated as one class, and
      labels are the indices of the sorted class names.
      """

      def __init__(self, root: str, check_valid_func=None, class_name: bool = False):
          r"""
          Index every valid image found below ``root``.

          :param root: root directory of an image folder. A leading ``~`` is
              expanded when the directory is scanned.
          :param check_valid_func: a function called with a file path that
              returns whether the file is an expected image file; if ``None``,
              the default ``is_img`` check is used.
          :param class_name: if ``True``, return class name instead of class index.
          """
          super().__init__(root, order=("image", "image_category"))
          # Keep the caller's original spelling of the path; it is expanded
          # only at the points where the file system is actually touched.
          self.root = root
          self.check_valid = is_img if check_valid_func is None else check_valid_func
          self.class_name = class_name
          # Classes must be collected first: collect_samples() iterates them.
          self.class_dict = self.collect_class()
          self.samples = self.collect_samples()

      def collect_samples(self) -> List:
          r"""
          Walk every class directory and gather the accepted files.

          :return: list of ``(path, label)`` tuples, where ``label`` is the
              class name if ``self.class_name`` is set, else the class index.
              Classes, visited directories and file names are each processed
              in sorted order.
          """
          samples = []
          directory = os.path.expanduser(self.root)
          for key in sorted(self.class_dict):
              class_dir = os.path.join(directory, key)
              # A class directory may have vanished since collect_class() ran.
              if not os.path.isdir(class_dir):
                  continue
              # The label is constant for a whole class; choose it once.
              label = key if self.class_name else self.class_dict[key]
              # Walked directory paths are unique, so sorting on the path alone
              # gives the same order as sorting the full os.walk tuples.
              walked = sorted(
                  os.walk(class_dir, followlinks=True), key=lambda entry: entry[0]
              )
              for current_dir, _, filenames in walked:
                  for name in sorted(filenames):
                      path = os.path.join(current_dir, name)
                      if self.check_valid(path):
                          samples.append((path, label))
          return samples

      def collect_class(self) -> Dict:
          r"""
          Map every immediate sub-directory of the root to a class index.

          :return: dict from class (directory) name to its ``np.int32`` index
              in the sorted list of class names.
          """
          # Expand "~" here as well, so that this scan and collect_samples()
          # look at the same directory.
          directory = os.path.expanduser(self.root)
          classes = sorted(
              entry.name for entry in os.scandir(directory) if entry.is_dir()
          )
          return {name: np.int32(index) for index, name in enumerate(classes)}

      def __getitem__(self, index: int) -> Tuple:
          r"""
          Load the sample stored at ``index``.

          :param index: position in ``self.samples``.
          :return: ``(img, label)``, where ``img`` is the result of
              ``cv2.imread(path, cv2.IMREAD_COLOR)``.
          """
          path, label = self.samples[index]
          # NOTE(review): per the OpenCV documentation, imread does not raise
          # on an unreadable file but returns an empty result (None in Python);
          # callers should be prepared for that.
          img = cv2.imread(path, cv2.IMREAD_COLOR)
          return img, label

      def __len__(self):
          r"""
          :return: number of collected ``(path, label)`` samples.
          """
          return len(self.samples)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台