You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

data-augumentation.py 8.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. # -*- coding: utf-8 -*-
  2. # ---
  3. # jupyter:
  4. # jupytext_format_version: '1.2'
  5. # kernelspec:
  6. # display_name: Python 3
  7. # language: python
  8. # name: python3
  9. # language_info:
  10. # codemirror_mode:
  11. # name: ipython
  12. # version: 3
  13. # file_extension: .py
  14. # mimetype: text/x-python
  15. # name: python
  16. # nbconvert_exporter: python
  17. # pygments_lexer: ipython3
  18. # version: 3.5.2
  19. # ---
  20. # # 数据增强
  21. # 前面我们已经讲了几个非常著名的卷积网络的结构,但是单单只靠这些网络并不能取得 state-of-the-art 的结果,现实问题往往更加复杂,非常容易出现过拟合的问题,而数据增强的方法是对抗过拟合问题的一个重要方法。
  22. #
  23. # 2012 年 AlexNet 在 ImageNet 上大获全胜,图片增强方法功不可没,因为有了图片增强,使得训练的数据集比实际数据集多了很多'新'样本,减少了过拟合的问题,下面我们来具体解释一下。
  24. # ## 常用的数据增强方法
  25. # 常用的数据增强方法如下:
  26. # 1.对图片进行一定比例缩放
  27. # 2.对图片进行随机位置的截取
  28. # 3.对图片进行随机的水平和竖直翻转
  29. # 4.对图片进行随机角度的旋转
  30. # 5.对图片进行亮度、对比度和颜色的随机变化
  31. #
  32. # 这些方法 pytorch 都已经为我们内置在了 torchvision 里面,我们在安装 pytorch 的时候也安装了 torchvision,下面我们来依次展示一下这些数据增强方法
  33. # +
  34. import sys
  35. sys.path.append('..')
  36. from PIL import Image
  37. from torchvision import transforms as tfs
  38. # -
  39. # 读入一张图片
  40. im = Image.open('./cat.png')
  41. im
  42. # ### 随机比例放缩
  43. # 随机比例缩放主要使用的是 `torchvision.transforms.Resize()` 这个函数,第一个参数可以是一个整数,那么图片会保存现在的宽和高的比例,并将更短的边缩放到这个整数的大小,第一个参数也可以是一个 tuple,那么图片会直接把宽和高缩放到这个大小;第二个参数表示放缩图片使用的方法,比如最邻近法,或者双线性差值等,一般双线性差值能够保留图片更多的信息,所以 pytorch 默认使用的是双线性差值,你可以手动去改这个参数,更多的信息可以看看[文档](http://pytorch.org/docs/0.3.0/torchvision/transforms.html)
  44. # 比例缩放
  45. print('before scale, shape: {}'.format(im.size))
  46. new_im = tfs.Resize((100, 200))(im)
  47. print('after scale, shape: {}'.format(new_im.size))
  48. new_im
  49. # ### 随机位置截取
  50. # 随机位置截取能够提取出图片中局部的信息,使得网络接受的输入具有多尺度的特征,所以能够有较好的效果。在 torchvision 中主要有下面两种方式,一个是 `torchvision.transforms.RandomCrop()`,传入的参数就是截取出的图片的长和宽,对图片在随机位置进行截取;第二个是 `torchvision.transforms.CenterCrop()`,同样传入介曲初的图片的大小作为参数,会在图片的中心进行截取
  51. # 随机裁剪出 100 x 100 的区域
  52. random_im1 = tfs.RandomCrop(100)(im)
  53. random_im1
  54. # 随机裁剪出 150 x 100 的区域
  55. random_im2 = tfs.RandomCrop((150, 100))(im)
  56. random_im2
  57. # 中心裁剪出 100 x 100 的区域
  58. center_im = tfs.CenterCrop(100)(im)
  59. center_im
  60. # ### 随机的水平和竖直方向翻转
  61. # 对于上面这一张猫的图片,如果我们将它翻转一下,它仍然是一张猫,但是图片就有了更多的多样性,所以随机翻转也是一种非常有效的手段。在 torchvision 中,随机翻转使用的是 `torchvision.transforms.RandomHorizontalFlip()` 和 `torchvision.transforms.RandomVerticalFlip()`
  62. # 随机水平翻转
  63. h_filp = tfs.RandomHorizontalFlip()(im)
  64. h_filp
  65. # 随机竖直翻转
  66. v_flip = tfs.RandomVerticalFlip()(im)
  67. v_flip
  68. # ### 随机角度旋转
  69. # 一些角度的旋转仍然是非常有用的数据增强方式,在 torchvision 中,使用 `torchvision.transforms.RandomRotation()` 来实现,其中第一个参数就是随机旋转的角度,比如填入 10,那么每次图片就会在 -10 ~ 10 度之间随机旋转
  70. rot_im = tfs.RandomRotation(45)(im)
  71. rot_im
  72. # ### 亮度、对比度和颜色的变化
  73. # 除了形状变化外,颜色变化又是另外一种增强方式,其中可以设置亮度变化,对比度变化和颜色变化等,在 torchvision 中主要使用 `torchvision.transforms.ColorJitter()` 来实现的,第一个参数就是亮度的比例,第二个是对比度,第三个是饱和度,第四个是颜色
  74. # 亮度
  75. bright_im = tfs.ColorJitter(brightness=1)(im) # 随机从 0 ~ 2 之间亮度变化,1 表示原图
  76. bright_im
  77. # 对比度
  78. contrast_im = tfs.ColorJitter(contrast=1)(im) # 随机从 0 ~ 2 之间对比度变化,1 表示原图
  79. contrast_im
  80. # 颜色
  81. color_im = tfs.ColorJitter(hue=0.5)(im) # 随机从 -0.5 ~ 0.5 之间对颜色变化
  82. color_im
  83. #
  84. #
  85. # 上面我们讲了这么图片增强的方法,其实这些方法都不是孤立起来用的,可以联合起来用,比如先做随机翻转,然后随机截取,再做对比度增强等等,torchvision 里面有个非常方便的函数能够将这些变化合起来,就是 `torchvision.transforms.Compose()`,下面我们举个例子
  86. im_aug = tfs.Compose([
  87. tfs.Resize(120),
  88. tfs.RandomHorizontalFlip(),
  89. tfs.RandomCrop(96),
  90. tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5)
  91. ])
  92. import matplotlib.pyplot as plt
  93. # %matplotlib inline
  94. nrows = 3
  95. ncols = 3
  96. figsize = (8, 8)
  97. _, figs = plt.subplots(nrows, ncols, figsize=figsize)
  98. for i in range(nrows):
  99. for j in range(ncols):
  100. figs[i][j].imshow(im_aug(im))
  101. figs[i][j].axes.get_xaxis().set_visible(False)
  102. figs[i][j].axes.get_yaxis().set_visible(False)
  103. plt.show()
  104. # 可以看到每次做完增强之后的图片都有一些变化,所以这就是我们前面讲的,增加了一些'新'数据
  105. #
  106. # 下面我们使用图像增强进行训练网络,看看具体的提升究竟在什么地方,使用前面讲的 ResNet 进行训练
  107. # + {"ExecuteTime": {"start_time": "2017-12-23T05:04:02.920639Z", "end_time": "2017-12-23T05:04:03.407434Z"}}
  108. import numpy as np
  109. import torch
  110. from torch import nn
  111. import torch.nn.functional as F
  112. from torch.autograd import Variable
  113. from torchvision.datasets import CIFAR10
  114. from utils import train, resnet
  115. from torchvision import transforms as tfs
  116. # + {"ExecuteTime": {"start_time": "2017-12-23T05:04:03.459562Z", "end_time": "2017-12-23T05:04:04.743167Z"}}
  117. # 使用数据增强
  118. def train_tf(x):
  119. im_aug = tfs.Compose([
  120. tfs.Resize(120),
  121. tfs.RandomHorizontalFlip(),
  122. tfs.RandomCrop(96),
  123. tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
  124. tfs.ToTensor(),
  125. tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  126. ])
  127. x = im_aug(x)
  128. return x
  129. def test_tf(x):
  130. im_aug = tfs.Compose([
  131. tfs.Resize(96),
  132. tfs.ToTensor(),
  133. tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  134. ])
  135. x = im_aug(x)
  136. return x
  137. train_set = CIFAR10('./data', train=True, transform=train_tf)
  138. train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
  139. test_set = CIFAR10('./data', train=False, transform=test_tf)
  140. test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
  141. net = resnet(3, 10)
  142. optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
  143. criterion = nn.CrossEntropyLoss()
  144. # + {"ExecuteTime": {"start_time": "2017-12-23T05:04:04.745540Z", "end_time": "2017-12-23T05:08:51.433955Z"}}
  145. train(net, train_data, test_data, 10, optimizer, criterion)
  146. # + {"ExecuteTime": {"start_time": "2017-12-23T05:09:21.756986Z", "end_time": "2017-12-23T05:09:22.997927Z"}}
  147. # 不使用数据增强
  148. def data_tf(x):
  149. im_aug = tfs.Compose([
  150. tfs.Resize(96),
  151. tfs.ToTensor(),
  152. tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  153. ])
  154. x = im_aug(x)
  155. return x
  156. train_set = CIFAR10('./data', train=True, transform=data_tf)
  157. train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
  158. test_set = CIFAR10('./data', train=False, transform=data_tf)
  159. test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
  160. net = resnet(3, 10)
  161. optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
  162. criterion = nn.CrossEntropyLoss()
  163. # + {"ExecuteTime": {"start_time": "2017-12-23T05:09:23.000573Z", "end_time": "2017-12-23T05:13:57.898751Z"}}
  164. train(net, train_data, test_data, 10, optimizer, criterion)
  165. # -
  166. # 从上面可以看出,对于训练集,不做数据增强跑 10 次,准确率已经到了 95%,而使用了数据增强,跑 10 次准确率只有 75%,说明数据增强之后变得更难了。
  167. #
  168. # 而对于测试集,使用数据增强进行训练的时候,准确率会比不使用更高,因为数据增强提高了模型应对于更多的不同数据集的泛化能力,所以有更好的效果。

机器学习越来越多应用到飞行器、机器人等领域,其目的是利用计算机实现类似人类的智能,从而实现装备的智能化与无人化。本课程旨在引导学生掌握机器学习的基本知识、典型方法与技术,通过具体的应用案例激发学生对该学科的兴趣,鼓励学生能够从人工智能的角度来分析、解决飞行器、机器人所面临的问题和挑战。本课程主要内容包括Python编程基础,机器学习模型,无监督学习、监督学习、深度学习基础知识与实现,并学习如何利用机器学习解决实际问题,从而全面提升自我的《综合能力》。