
cifar.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
import pickle
import tarfile
from typing import Tuple

import numpy as np

from ....logger import get_logger
from .meta_vision import VisionDataset
from .utils import _default_dataset_root, load_raw_data_from_url

logger = get_logger(__name__)


class CIFAR10(VisionDataset):
    r""":class:`~.Dataset` for CIFAR10 meta data."""

    url_path = "http://www.cs.utoronto.ca/~kriz/"
    raw_file_name = "cifar-10-python.tar.gz"
    raw_file_md5 = "c58f30108f718f92721af3b95e74349a"
    raw_file_dir = "cifar-10-batches-py"
    train_batch = [
        "data_batch_1",
        "data_batch_2",
        "data_batch_3",
        "data_batch_4",
        "data_batch_5",
    ]
    test_batch = ["test_batch"]
    meta_info = {"name": "batches.meta"}

    def __init__(
        self,
        root: str = None,
        train: bool = True,
        download: bool = True,
        timeout: int = 500,
    ):
        super().__init__(root, order=("image", "image_category"))

        self.timeout = timeout

        # process the root path
        if root is None:
            self.root = self._default_root
            if not os.path.exists(self.root):
                os.makedirs(self.root)
        else:
            self.root = root
            if not os.path.exists(self.root):
                if download:
                    logger.debug(
                        "dir %s does not exist, will be automatically created",
                        self.root,
                    )
                    os.makedirs(self.root)
                else:
                    raise ValueError("dir %s does not exist" % self.root)

        self.target_file = os.path.join(self.root, self.raw_file_dir)

        # check existence of the target pickle dir; if it exists, load the
        # pickle files no matter what download is set to
        if os.path.exists(self.target_file):
            if train:
                self.arrays = self.bytes2array(self.train_batch)
            else:
                self.arrays = self.bytes2array(self.test_batch)
        else:
            if download:
                self.download()
                if train:
                    self.arrays = self.bytes2array(self.train_batch)
                else:
                    self.arrays = self.bytes2array(self.test_batch)
            else:
                raise ValueError(
                    "dir does not contain target file %s, please set download=True"
                    % (self.target_file)
                )

    def __getitem__(self, index: int) -> Tuple:
        return tuple(array[index] for array in self.arrays)

    def __len__(self) -> int:
        return len(self.arrays[0])

    @property
    def _default_root(self):
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def meta(self):
        meta_path = os.path.join(self.root, self.raw_file_dir, self.meta_info["name"])
        with open(meta_path, "rb") as f:
            meta = pickle.load(f, encoding="bytes")
        return meta

    def download(self):
        url = self.url_path + self.raw_file_name
        load_raw_data_from_url(
            url, self.raw_file_name, self.raw_file_md5, self.root, self.timeout
        )
        self.process()

    def untar(self, file_path, dirs):
        assert file_path.endswith(".tar.gz")
        logger.debug("untar file %s to %s", file_path, dirs)
        t = tarfile.open(file_path)
        t.extractall(path=dirs)

    def bytes2array(self, filenames):
        data = []
        label = []
        for filename in filenames:
            path = os.path.join(self.root, self.raw_file_dir, filename)
            logger.debug("unpickle file %s", path)
            with open(path, "rb") as fo:
                dic = pickle.load(fo, encoding="bytes")
                # raw data is (N, 3072); reshape to NCHW, then transpose to NHWC
                batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
                # the pickles store channels as RGB; reverse them to BGR
                data.extend(list(batch_data[..., [2, 1, 0]]))
                label.extend(dic[b"labels"])
        label = np.array(label, dtype=np.int32)
        return (data, label)

    def process(self):
        logger.info("process raw data ...")
        self.untar(os.path.join(self.root, self.raw_file_name), self.root)


class CIFAR100(CIFAR10):
    r""":class:`~.Dataset` for CIFAR100 meta data."""

    url_path = "http://www.cs.utoronto.ca/~kriz/"
    raw_file_name = "cifar-100-python.tar.gz"
    raw_file_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
    raw_file_dir = "cifar-100-python"
    train_batch = ["train"]
    test_batch = ["test"]
    meta_info = {"name": "meta"}

    @property
    def meta(self):
        meta_path = os.path.join(self.root, self.raw_file_dir, self.meta_info["name"])
        with open(meta_path, "rb") as f:
            meta = pickle.load(f, encoding="bytes")
        return meta

    def bytes2array(self, filenames):
        data = []
        fine_label = []
        coarse_label = []
        for filename in filenames:
            path = os.path.join(self.root, self.raw_file_dir, filename)
            logger.debug("unpickle file %s", path)
            with open(path, "rb") as fo:
                dic = pickle.load(fo, encoding="bytes")
                batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
                # same RGB -> BGR channel reordering as CIFAR10
                data.extend(list(batch_data[..., [2, 1, 0]]))
                fine_label.extend(dic[b"fine_labels"])
                coarse_label.extend(dic[b"coarse_labels"])
        fine_label = np.array(fine_label, dtype=np.int32)
        coarse_label = np.array(coarse_label, dtype=np.int32)
        return data, fine_label, coarse_label
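
For reference, a minimal usage sketch of these classes through MegEngine's public API. The import paths (megengine.data.dataset, megengine.data.DataLoader, RandomSampler) follow standard MegEngine usage; the batch size and print statements are illustrative only, and the b"label_names" key comes from the official CIFAR-10 metadata pickle:

    from megengine.data import DataLoader, RandomSampler
    from megengine.data.dataset import CIFAR10, CIFAR100

    # download=True fetches, verifies, and untars the archive on first use
    train_set = CIFAR10(root=None, train=True, download=True)
    print(len(train_set))        # 50000 training samples
    image, label = train_set[0]  # image: (32, 32, 3) uint8 array in BGR order

    # human-readable class names from the unpickled metadata
    print([n.decode() for n in train_set.meta[b"label_names"]])

    # batch with a sampler + DataLoader; each batch follows the dataset
    # order ("image", "image_category")
    sampler = RandomSampler(train_set, batch_size=64)
    for images, labels in DataLoader(train_set, sampler=sampler):
        print(images.shape, labels.shape)  # (64, 32, 32, 3) (64,)
        break

    # CIFAR100 yields three items per sample: image, fine label, coarse label
    test_set = CIFAR100(train=False, download=True)
    image, fine, coarse = test_set[0]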

The MegEngine package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware installed along with a working driver. If you would like to try deep learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
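
A short sanity-check sketch: is_cuda_available, set_default_device, and get_default_device are part of MegEngine's public API, while the exact device strings ("gpu0"/"cpu0") follow common MegEngine convention and should be treated as an assumption here:

    import megengine as mge

    # pick a device depending on whether the bundled CUDA environment
    # actually finds a usable GPU on this machine
    if mge.is_cuda_available():
        mge.set_default_device("gpu0")  # first GPU
    else:
        mge.set_default_device("cpu0")  # CPU fallback

    print(mge.get_default_device())
    x = mge.tensor([1.0, 2.0, 3.0])  # created on the default device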