You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

3_CNN_CIFAR.py 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import torch as t
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch import optim
  5. from torch.autograd import Variable
  6. import torchvision as tv
  7. import torchvision.transforms as transforms
  8. from torchvision.transforms import ToPILImage
  9. show = ToPILImage() # 可以把Tensor转成Image,方便可视化
  10. # 第一次运行程序torchvision会自动下载CIFAR-10数据集,
  11. # 大约100M,需花费一定的时间,
  12. # 如果已经下载有CIFAR-10,可通过root参数指定
  13. # 定义对数据的预处理
  14. transform = transforms.Compose([
  15. transforms.ToTensor(), # 转为Tensor
  16. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 归一化
  17. ])
  18. # 训练集
  19. dataset_path = "../data"
  20. trainset = tv.datasets.CIFAR10(
  21. root=dataset_path, train=True, download=True, transform=transform)
  22. trainloader = t.utils.data.DataLoader(
  23. trainset,
  24. batch_size=4,
  25. shuffle=True,
  26. num_workers=2)
  27. # 测试集
  28. testset = tv.datasets.CIFAR10(
  29. root=dataset_path, train=False, download=True, transform=transform)
  30. testloader = t.utils.data.DataLoader(
  31. testset,
  32. batch_size=4,
  33. shuffle=False,
  34. num_workers=2)
  35. classes = ('plane', 'car', 'bird', 'cat', 'deer',
  36. 'dog', 'frog', 'horse', 'ship', 'truck')
  37. # Define the network
  38. class Net(nn.Module):
  39. def __init__(self):
  40. super(Net, self).__init__()
  41. self.conv1 = nn.Conv2d(3, 6, 5)
  42. self.conv2 = nn.Conv2d(6, 16, 5)
  43. self.fc1 = nn.Linear(16*5*5, 120)
  44. self.fc2 = nn.Linear(120, 84)
  45. self.fc3 = nn.Linear(84, 10)
  46. def forward(self, x):
  47. x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
  48. x = F.max_pool2d(F.relu(self.conv2(x)), 2)
  49. x = x.view(x.size()[0], -1)
  50. x = F.relu(self.fc1(x))
  51. x = F.relu(self.fc2(x))
  52. x = self.fc3(x)
  53. return x
  54. net = Net()
  55. print(net)
  56. criterion = nn.CrossEntropyLoss()
  57. optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
  58. t.set_num_threads(8)
  59. for epoch in range(2):
  60. running_loss = 0.0
  61. for i, data in enumerate(trainloader, 0):
  62. # 输入数据
  63. inputs, labels = data
  64. inputs, labels = Variable(inputs), Variable(labels)
  65. # 梯度清零
  66. optimizer.zero_grad()
  67. # forward + backward
  68. outputs = net(inputs)
  69. loss = criterion(outputs, labels)
  70. loss.backward()
  71. # 更新参数
  72. optimizer.step()
  73. # 打印log信息
  74. running_loss += loss.data[0]
  75. if i % 2000 == 1999: # 每2000个batch打印一下训练状态
  76. print('[%d, %5d] loss: %.3f' \
  77. % (epoch + 1, i + 1, running_loss / 2000))
  78. running_loss = 0.0
  79. print('Finished Training')
  80. dataiter = iter(testloader)
  81. images, labels = dataiter.next() # 一个batch返回4张图片
  82. print('实际的label: ', ' '.join(\
  83. '%08s'%classes[labels[j]] for j in range(4)))
  84. show(tv.utils.make_grid(images / 2 - 0.5)).resize((400,100))
  85. # 计算图片在每个类别上的分数
  86. outputs = net(Variable(images))
  87. # 得分最高的那个类
  88. _, predicted = t.max(outputs.data, 1)
  89. print('预测结果: ', ' '.join('%5s'\
  90. % classes[predicted[j]] for j in range(4)))
  91. correct = 0 # 预测正确的图片数
  92. total = 0 # 总共的图片数
  93. for data in testloader:
  94. images, labels = data
  95. outputs = net(Variable(images))
  96. _, predicted = t.max(outputs.data, 1)
  97. total += labels.size(0)
  98. correct += (predicted == labels).sum()
  99. print('10000张测试集中的准确率为: %d %%' % (100 * correct / total))

机器学习越来越多应用到飞行器、机器人等领域,其目的是利用计算机实现类似人类的智能,从而实现装备的智能化与无人化。本课程旨在引导学生掌握机器学习的基本知识、典型方法与技术,通过具体的应用案例激发学生对该学科的兴趣,鼓励学生能够从人工智能的角度来分析、解决飞行器、机器人所面临的问题和挑战。本课程主要内容包括Python编程基础,机器学习模型,无监督学习、监督学习、深度学习基础知识与实现,并学习如何利用机器学习解决实际问题,从而全面提升自我的综合能力。