"""Build CIFAR-10 train/test datasets and DataLoaders for a small classifier demo.

On the first run torchvision downloads CIFAR-10 (roughly 100 MB, so this can
take a while).  If the dataset is already on disk elsewhere, point
``data_path`` (passed as the ``root`` argument) at it.
"""
import os

import torch as t
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage

# Converts a Tensor back into a PIL Image, handy for visualizing samples.
show = ToPILImage()

# Preprocessing applied to every sample:
#   ToTensor() turns the image into a float Tensor;
#   Normalize(mean=0.5, std=0.5) per channel computes (x - 0.5) / 0.5, which
#   maps ToTensor's [0, 1] range onto [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Data storage dir; only request a download when the extracted CIFAR-10
# folder is not already present there.
data_path = "../data/"
p = os.path.join(data_path, "cifar-10-batches-py")
do_download = not os.path.isdir(p)

# NOTE(review): num_workers=2 starts worker processes when a loader is
# iterated; on spawn-based platforms (Windows/macOS) that iteration usually
# needs to sit under an `if __name__ == "__main__":` guard — confirm for the
# code that consumes these loaders.

# Training set (shuffled every epoch).
trainset = tv.datasets.CIFAR10(
    root=data_path,
    train=True,
    download=do_download,
    transform=transform,
)
trainloader = t.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

# Test set (fixed order so evaluation is reproducible).
testset = tv.datasets.CIFAR10(
    root=data_path,
    train=False,
    download=do_download,
    transform=transform,
)
testloader = t.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
)

# Human-readable label names, indexed by CIFAR-10 class id.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')