
mnist.py 7.5 kB

# -*- coding: utf-8 -*-
import gzip
import os
import struct
from typing import Tuple

import numpy as np
from tqdm import tqdm

from ....logger import get_logger
from .meta_vision import VisionDataset
from .utils import _default_dataset_root, load_raw_data_from_url

logger = get_logger(__name__)


class MNIST(VisionDataset):
    r"""MNIST dataset.

    The MNIST_ database (Modified National Institute of Standards and Technology
    database) is a large database of handwritten digits that is commonly used for
    training various image processing systems. The database is also widely used
    for training and testing in the field of machine learning. It was created by
    "re-mixing" the samples from `NIST`_'s original datasets. Furthermore, the
    black and white images from NIST were normalized to fit into a 28x28 pixel
    bounding box and anti-aliased, which introduced grayscale levels. The MNIST
    database contains 60,000 training images and 10,000 testing images.

    The above introduction comes from `MNIST database - Wikipedia
    <https://en.wikipedia.org/wiki/MNIST_database>`_.

    Args:
        root: Path for downloading or loading the MNIST dataset. If it's ``None``,
            it will be set to ``~/.cache/megengine`` (the default root path).
        train: If ``True``, use the training set; otherwise use the test set.
        download: If ``True``, download the dataset from the internet and put it
            in the ``root`` directory. If the dataset is already downloaded, it
            is not downloaded again.

    Returns:
        The MNIST :class:`~.Dataset` that can work with :class:`~.DataLoader`.

    Example:

        >>> from megengine.data.dataset import MNIST  # doctest: +SKIP
        >>> mnist = MNIST("/data/datasets/MNIST")  # set the root path  # doctest: +SKIP
        >>> image, label = mnist[0]  # doctest: +SKIP
        >>> image.shape  # doctest: +SKIP
        (28, 28, 1)

    .. versionchanged:: 1.11 The original URL has been updated to a mirror URL

        *"Please refrain from accessing these files from automated scripts with
        high frequency. Make copies!"*

        As requested by the original provider of the MNIST dataset, the dataset
        is now downloaded from the mirror site:
        https://ossci-datasets.s3.amazonaws.com/mnist/

    .. seealso::

        * The MNIST dataset is used as an example in the :ref:`megengine-quick-start`
          tutorial.
        * You can find many machine learning projects that use the MNIST dataset
          on the internet.

    .. _MNIST: http://yann.lecun.com/exdb/mnist/
    .. _NIST: https://www.nist.gov/data
    """
    url_path = "https://ossci-datasets.s3.amazonaws.com/mnist/"
    raw_file_name = [
        "train-images-idx3-ubyte.gz",
        "train-labels-idx1-ubyte.gz",
        "t10k-images-idx3-ubyte.gz",
        "t10k-labels-idx1-ubyte.gz",
    ]
    raw_file_md5 = [
        "f68b3c2dcbeaaa9fbdd348bbdeb94873",
        "d53e105ee54ea40749a09fcbcd1e9432",
        "9fb629c4189551a2d022fa330f9573f3",
        "ec29112dd5afa0611ce80d1b7f02629c",
    ]
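    # ``process`` below pairs these by index: entries 0/1 are the training
    # images/labels and entries 2/3 the test images/labels; ``raw_file_md5``
    # follows the same order.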

    def __init__(
        self, root: str = None, train: bool = True, download: bool = True,
    ):
        super().__init__(root, order=("image", "image_category"))

        # process the root path
        if root is None:
            self.root = self._default_root
            if not os.path.exists(self.root):
                os.makedirs(self.root)
        else:
            self.root = root
            if not os.path.exists(self.root):
                if download:
                    logger.debug(
                        "dir %s does not exist, will be automatically created",
                        self.root,
                    )
                    os.makedirs(self.root)
                else:
                    raise ValueError("dir %s does not exist" % self.root)

        if self._check_raw_files():
            self.process(train)
        elif download:
            self.download()
            self.process(train)
        else:
            raise ValueError(
                "root does not contain valid raw files, please set download=True"
            )

    def __getitem__(self, index: int) -> Tuple:
        return tuple(array[index] for array in self.arrays)

    def __len__(self) -> int:
        return len(self.arrays[0])

    @property
    def _default_root(self):
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def meta(self):
        return self._meta_data

    def _check_raw_files(self):
        return all(
            [
                os.path.exists(os.path.join(self.root, path))
                for path in self.raw_file_name
            ]
        )

    def download(self):
        for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5):
            url = self.url_path + file_name
            load_raw_data_from_url(url, file_name, md5, self.root)
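    # NOTE: ``_check_raw_files`` above only tests that the files exist;
    # integrity checking is assumed to happen inside ``load_raw_data_from_url``,
    # which receives the expected md5 for each file (see ``.utils``).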

    def process(self, train):
        # load the raw files and turn them into meta data plus a tuple of numpy arrays
        logger.info("process the raw files of %s set...", "train" if train else "test")
        if train:
            meta_data_images, images = parse_idx3(
                os.path.join(self.root, self.raw_file_name[0])
            )
            meta_data_labels, labels = parse_idx1(
                os.path.join(self.root, self.raw_file_name[1])
            )
        else:
            meta_data_images, images = parse_idx3(
                os.path.join(self.root, self.raw_file_name[2])
            )
            meta_data_labels, labels = parse_idx1(
                os.path.join(self.root, self.raw_file_name[3])
            )

        self._meta_data = {
            "images": meta_data_images,
            "labels": meta_data_labels,
        }
        self.arrays = (images, labels.astype(np.int32))
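

# The raw files use the IDX format described at http://yann.lecun.com/exdb/mnist/:
# a big-endian header (a 32-bit magic number, then one 32-bit size per
# dimension) followed by the raw unsigned bytes. For idx3 image files the
# header is ``magic, num_images, height, width``; for idx1 label files it is
# ``magic, num_labels``. The two parsers below unpack exactly these layouts.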
def parse_idx3(idx3_file):
    # parse an idx3 file into meta data and a list of images as numpy arrays
    logger.debug("parse idx3 file %s ...", idx3_file)
    assert idx3_file.endswith(".gz")
    with gzip.open(idx3_file, "rb") as f:
        bin_data = f.read()

    # parse meta data
    offset = 0
    fmt_header = ">iiii"
    magic, imgs, height, width = struct.unpack_from(fmt_header, bin_data, offset)
    meta_data = {"magic": magic, "imgs": imgs, "height": height, "width": width}

    # parse images
    image_size = height * width
    offset += struct.calcsize(fmt_header)
    fmt_image = ">" + str(image_size) + "B"
    images = []
    bar = tqdm(total=meta_data["imgs"], ncols=80)
    for image in struct.iter_unpack(fmt_image, bin_data[offset:]):
        images.append(np.array(image, dtype=np.uint8).reshape((height, width, 1)))
        bar.update()
    bar.close()
    return meta_data, images
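

# An equivalent vectorized sketch (an illustration, not the shipped code): the
# whole pixel block after the 16-byte header can be read in one call and
# reshaped, at the cost of the per-image progress bar:
#
#     data = np.frombuffer(bin_data, dtype=np.uint8, offset=struct.calcsize(">iiii"))
#     images = data.reshape(imgs, height, width, 1)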


def parse_idx1(idx1_file):
    # parse an idx1 file into meta data and labels as a numpy array
    logger.debug("parse idx1 file %s ...", idx1_file)
    assert idx1_file.endswith(".gz")
    with gzip.open(idx1_file, "rb") as f:
        bin_data = f.read()

    # parse meta data
    offset = 0
    fmt_header = ">ii"
    magic, imgs = struct.unpack_from(fmt_header, bin_data, offset)
    meta_data = {"magic": magic, "imgs": imgs}

    # parse labels
    offset += struct.calcsize(fmt_header)
    fmt_image = ">B"
    labels = np.empty(imgs, dtype=int)
    bar = tqdm(total=meta_data["imgs"], ncols=80)
    for i, label in enumerate(struct.iter_unpack(fmt_image, bin_data[offset:])):
        labels[i] = label[0]
        bar.update()
    bar.close()
    return meta_data, labels
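

# A minimal usage sketch (illustrative, not part of the module). This file uses
# package-relative imports, so run the example from a separate script via the
# public import path:
#
#     from megengine.data.dataset import MNIST
#
#     dataset = MNIST(root=None, train=True, download=True)  # downloads on first run
#     print(len(dataset))        # 60000 samples in the training set
#     image, label = dataset[0]
#     print(image.shape, label)  # (28, 28, 1) and an int32 label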