Browse Source

fix(mge/data): create dir automatically when download=True

GitOrigin-RevId: f3cb1b7d50
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
aa2bfd2d85
2 changed files with 19 additions and 6 deletions
  1. +11
    -5
      python_module/megengine/data/dataset/vision/cifar.py
  2. +8
    -1
      python_module/megengine/data/dataset/vision/mnist.py

+ 11
- 5
python_module/megengine/data/dataset/vision/cifar.py View File

@@ -57,7 +57,14 @@ class CIFAR10(VisionDataset):
else: else:
self.root = root self.root = root
if not os.path.exists(self.root): if not os.path.exists(self.root):
raise ValueError("dir %s does not exist" % self.root)
if download:
logger.debug(
"dir %s does not exist, will be automatically created",
self.root,
)
os.makedirs(self.root)
else:
raise ValueError("dir %s does not exist" % self.root)


self.target_file = os.path.join(self.root, self.raw_file_dir) self.target_file = os.path.join(self.root, self.raw_file_dir)


@@ -77,8 +84,7 @@ class CIFAR10(VisionDataset):
self.arrays = self.bytes2array(self.test_batch) self.arrays = self.bytes2array(self.test_batch)
else: else:
raise ValueError( raise ValueError(
"dir does not contain target file\
%s, please set download=True"
"dir does not contain target file %s, please set download=True"
% (self.target_file) % (self.target_file)
) )


@@ -160,6 +166,6 @@ class CIFAR100(CIFAR10):
data.extend(list(batch_data[..., [2, 1, 0]])) data.extend(list(batch_data[..., [2, 1, 0]]))
fine_label.extend(dic[b"fine_labels"]) fine_label.extend(dic[b"fine_labels"])
coarse_label.extend(dic[b"coarse_labels"]) coarse_label.extend(dic[b"coarse_labels"])
fine_label = np.array(fine_label)
coarse_label = np.array(coarse_label)
fine_label = np.array(fine_label, dtype=np.int32)
coarse_label = np.array(coarse_label, dtype=np.int32)
return data, fine_label, coarse_label return data, fine_label, coarse_label

+ 8
- 1
python_module/megengine/data/dataset/vision/mnist.py View File

@@ -75,7 +75,14 @@ class MNIST(VisionDataset):
else: else:
self.root = root self.root = root
if not os.path.exists(self.root): if not os.path.exists(self.root):
raise ValueError("dir %s does not exist" % self.root)
if download:
logger.debug(
"dir %s does not exist, will be automatically created",
self.root,
)
os.makedirs(self.root)
else:
raise ValueError("dir %s does not exist" % self.root)


if self._check_raw_files(): if self._check_raw_files():
self.process(train) self.process(train)


Loading…
Cancel
Save