You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

cifar.py 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import os
  10. import pickle
  11. import tarfile
  12. from typing import Tuple
  13. import numpy as np
  14. from ....logger import get_logger
  15. from .meta_vision import VisionDataset
  16. from .utils import _default_dataset_root, load_raw_data_from_url
  17. logger = get_logger(__name__)
  18. class CIFAR10(VisionDataset):
  19. r""" ``Dataset`` for CIFAR10 meta data.
  20. """
  21. url_path = "http://www.cs.utoronto.ca/~kriz/"
  22. raw_file_name = "cifar-10-python.tar.gz"
  23. raw_file_md5 = "c58f30108f718f92721af3b95e74349a"
  24. raw_file_dir = "cifar-10-batches-py"
  25. train_batch = [
  26. "data_batch_1",
  27. "data_batch_2",
  28. "data_batch_3",
  29. "data_batch_4",
  30. "data_batch_5",
  31. ]
  32. test_batch = ["test_batch"]
  33. meta_info = {"name": "batches.meta"}
  34. def __init__(
  35. self,
  36. root: str = None,
  37. train: bool = True,
  38. download: bool = True,
  39. timeout: int = 500,
  40. ):
  41. super().__init__(root, order=("image", "image_category"))
  42. self.timeout = timeout
  43. # process the root path
  44. if root is None:
  45. self.root = self._default_root
  46. if not os.path.exists(self.root):
  47. os.makedirs(self.root)
  48. else:
  49. self.root = root
  50. if not os.path.exists(self.root):
  51. if download:
  52. logger.debug(
  53. "dir %s does not exist, will be automatically created",
  54. self.root,
  55. )
  56. os.makedirs(self.root)
  57. else:
  58. raise ValueError("dir %s does not exist" % self.root)
  59. self.target_file = os.path.join(self.root, self.raw_file_dir)
  60. # check existence of target pickle dir, if exists load the
  61. # pickle file no matter what download is set
  62. if os.path.exists(self.target_file):
  63. if train:
  64. self.arrays = self.bytes2array(self.train_batch)
  65. else:
  66. self.arrays = self.bytes2array(self.test_batch)
  67. else:
  68. if download:
  69. self.download()
  70. if train:
  71. self.arrays = self.bytes2array(self.train_batch)
  72. else:
  73. self.arrays = self.bytes2array(self.test_batch)
  74. else:
  75. raise ValueError(
  76. "dir does not contain target file %s, please set download=True"
  77. % (self.target_file)
  78. )
  79. def __getitem__(self, index: int) -> Tuple:
  80. return tuple(array[index] for array in self.arrays)
  81. def __len__(self) -> int:
  82. return len(self.arrays[0])
  83. @property
  84. def _default_root(self):
  85. return os.path.join(_default_dataset_root(), self.__class__.__name__)
  86. @property
  87. def meta(self):
  88. meta_path = os.path.join(self.root, self.raw_file_dir, self.meta_info["name"])
  89. with open(meta_path, "rb") as f:
  90. meta = pickle.load(f, encoding="bytes")
  91. return meta
  92. def download(self):
  93. url = self.url_path + self.raw_file_name
  94. load_raw_data_from_url(
  95. url, self.raw_file_name, self.raw_file_md5, self.root, self.timeout
  96. )
  97. self.process()
  98. def untar(self, file_path, dirs):
  99. assert file_path.endswith(".tar.gz")
  100. logger.debug("untar file %s to %s", file_path, dirs)
  101. t = tarfile.open(file_path)
  102. t.extractall(path=dirs)
  103. def bytes2array(self, filenames):
  104. data = []
  105. label = []
  106. for filename in filenames:
  107. path = os.path.join(self.root, self.raw_file_dir, filename)
  108. logger.debug("unpickle file %s", path)
  109. with open(path, "rb") as fo:
  110. dic = pickle.load(fo, encoding="bytes")
  111. batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
  112. data.extend(list(batch_data[..., [2, 1, 0]]))
  113. label.extend(dic[b"labels"])
  114. label = np.array(label, dtype=np.int32)
  115. return (data, label)
  116. def process(self):
  117. logger.info("process raw data ...")
  118. self.untar(os.path.join(self.root, self.raw_file_name), self.root)
  119. class CIFAR100(CIFAR10):
  120. url_path = "http://www.cs.utoronto.ca/~kriz/"
  121. raw_file_name = "cifar-100-python.tar.gz"
  122. raw_file_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
  123. raw_file_dir = "cifar-100-python"
  124. train_batch = ["train"]
  125. test_batch = ["test"]
  126. meta_info = {"name": "meta"}
  127. @property
  128. def meta(self):
  129. meta_path = os.path.join(self.root, self.raw_file_dir, self.meta_info["name"])
  130. with open(meta_path, "rb") as f:
  131. meta = pickle.load(f, encoding="bytes")
  132. return meta
  133. def bytes2array(self, filenames):
  134. data = []
  135. fine_label = []
  136. coarse_label = []
  137. for filename in filenames:
  138. path = os.path.join(self.root, self.raw_file_dir, filename)
  139. logger.debug("unpickle file %s", path)
  140. with open(path, "rb") as fo:
  141. dic = pickle.load(fo, encoding="bytes")
  142. batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
  143. data.extend(list(batch_data[..., [2, 1, 0]]))
  144. fine_label.extend(dic[b"fine_labels"])
  145. coarse_label.extend(dic[b"coarse_labels"])
  146. fine_label = np.array(fine_label, dtype=np.int32)
  147. coarse_label = np.array(coarse_label, dtype=np.int32)
  148. return data, fine_label, coarse_label

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台