|
|
@@ -57,7 +57,14 @@ class CIFAR10(VisionDataset): |
|
|
|
else: |
|
|
|
self.root = root |
|
|
|
if not os.path.exists(self.root): |
|
|
|
raise ValueError("dir %s does not exist" % self.root) |
|
|
|
if download: |
|
|
|
logger.debug( |
|
|
|
"dir %s does not exist, will be automatically created", |
|
|
|
self.root, |
|
|
|
) |
|
|
|
os.makedirs(self.root) |
|
|
|
else: |
|
|
|
raise ValueError("dir %s does not exist" % self.root) |
|
|
|
|
|
|
|
self.target_file = os.path.join(self.root, self.raw_file_dir) |
|
|
|
|
|
|
@@ -77,8 +84,7 @@ class CIFAR10(VisionDataset): |
|
|
|
self.arrays = self.bytes2array(self.test_batch) |
|
|
|
else: |
|
|
|
raise ValueError( |
|
|
|
"dir does not contain target file\ |
|
|
|
%s, please set download=True" |
|
|
|
"dir does not contain target file %s, please set download=True" |
|
|
|
% (self.target_file) |
|
|
|
) |
|
|
|
|
|
|
@@ -160,6 +166,6 @@ class CIFAR100(CIFAR10): |
|
|
|
data.extend(list(batch_data[..., [2, 1, 0]])) |
|
|
|
fine_label.extend(dic[b"fine_labels"]) |
|
|
|
coarse_label.extend(dic[b"coarse_labels"]) |
|
|
|
fine_label = np.array(fine_label) |
|
|
|
coarse_label = np.array(coarse_label) |
|
|
|
fine_label = np.array(fine_label, dtype=np.int32) |
|
|
|
coarse_label = np.array(coarse_label, dtype=np.int32) |
|
|
|
return data, fine_label, coarse_label |