You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

densenet.py 7.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # -*- coding: utf-8 -*-
  2. # ---
  3. # jupyter:
  4. # jupytext_format_version: '1.2'
  5. # kernelspec:
  6. # display_name: Python 3
  7. # language: python
  8. # name: python3
  9. # language_info:
  10. # codemirror_mode:
  11. # name: ipython
  12. # version: 3
  13. # file_extension: .py
  14. # mimetype: text/x-python
  15. # name: python
  16. # nbconvert_exporter: python
  17. # pygments_lexer: ipython3
  18. # version: 3.5.2
  19. # ---
  20. # # DenseNet
  21. # ResNet 提出了跨层连接的思想,这直接影响了随后出现的卷积网络架构,其中最有名的就是 CVPR 2017 的 best paper——DenseNet。
  22. #
  23. # DenseNet 和 ResNet 的不同在于:ResNet 是跨层求和,而 DenseNet 是跨层将特征在通道维度上进行拼接。下面可以看看它们两者的图示:
  24. #
  25. # ![](https://ws4.sinaimg.cn/large/006tNc79ly1fmpvj5vkfhj30uw0anq73.jpg)
  26. #
  27. # ![](https://ws1.sinaimg.cn/large/006tNc79ly1fmpvj7fxd1j30vb0eyzqf.jpg)
  28. # 第一张图是 ResNet,第二张图是 DenseNet。因为是在通道维度上进行特征拼接,所以浅层的输出会被保留并传入后面所有的层,这能够更好地保证梯度的传播,同时能够使用低层特征和高层特征进行联合训练,从而得到更好的结果。
  29. # DenseNet 主要由 dense block 构成,下面我们来实现一个 dense block
  30. # + {"ExecuteTime": {"start_time": "2017-12-22T15:38:30.612922Z", "end_time": "2017-12-22T15:38:31.113030Z"}}
  31. import sys
  32. sys.path.append('..')
  33. import numpy as np
  34. import torch
  35. from torch import nn
  36. from torch.autograd import Variable
  37. from torchvision.datasets import CIFAR10
  38. # -
  39. # 首先定义一个卷积块,这个卷积块的顺序是 bn -> relu -> conv
  40. # + {"ExecuteTime": {"start_time": "2017-12-22T15:38:31.115369Z", "end_time": "2017-12-22T15:38:31.121249Z"}}
  41. def conv_block(in_channel, out_channel):
  42. layer = nn.Sequential(
  43. nn.BatchNorm2d(in_channel),
  44. nn.ReLU(True),
  45. nn.Conv2d(in_channel, out_channel, 3, padding=1, bias=False)
  46. )
  47. return layer
  48. # -
  49. # dense block 将每次卷积输出的通道数称为 `growth_rate`,因此如果输入通道数是 `in_channel`,共有 n 层,那么输出通道数就是 `in_channel + n * growth_rate`
  50. # + {"ExecuteTime": {"start_time": "2017-12-22T15:38:31.123363Z", "end_time": "2017-12-22T15:38:31.145274Z"}}
  51. class dense_block(nn.Module):
  52. def __init__(self, in_channel, growth_rate, num_layers):
  53. super(dense_block, self).__init__()
  54. block = []
  55. channel = in_channel
  56. for i in range(num_layers):
  57. block.append(conv_block(channel, growth_rate))
  58. channel += growth_rate
  59. self.net = nn.Sequential(*block)
  60. def forward(self, x):
  61. for layer in self.net:
  62. out = layer(x)
  63. x = torch.cat((out, x), dim=1)
  64. return x
  65. # -
  66. # 我们验证一下输出的 channel 是否正确
  67. # + {"ExecuteTime": {"start_time": "2017-12-22T15:38:31.147196Z", "end_time": "2017-12-22T15:38:31.213632Z"}}
  68. test_net = dense_block(3, 12, 3)
  69. test_x = Variable(torch.zeros(1, 3, 96, 96))
  70. print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))
  71. test_y = test_net(test_x)
  72. print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))
  73. # -
  74. # 除了 dense block,DenseNet 中还有一个模块叫过渡层(transition block)。因为 DenseNet 会不断地在通道维度上进行拼接,所以当层数很多的时候,输出的通道数会越来越大,参数和计算量也会越来越大。为了避免这个问题,需要引入过渡层将输出通道数降低下来,同时将输入的长宽减半:其中用 1 x 1 的卷积降低通道数,再用步长为 2 的 2 x 2 平均池化将长宽减半。
  75. # + {"ExecuteTime": {"start_time": "2017-12-22T15:38:31.215770Z", "end_time": "2017-12-22T15:38:31.222120Z"}}
  76. def transition(in_channel, out_channel):
  77. trans_layer = nn.Sequential(
  78. nn.BatchNorm2d(in_channel),
  79. nn.ReLU(True),
  80. nn.Conv2d(in_channel, out_channel, 1),
  81. nn.AvgPool2d(2, 2)
  82. )
  83. return trans_layer
  84. # -
  85. # 验证一下过渡层是否正确
  86. # + {"ExecuteTime": {"start_time": "2017-12-22T15:38:31.224078Z", "end_time": "2017-12-22T15:38:31.234846Z"}}
  87. test_net = transition(3, 12)
  88. test_x = Variable(torch.zeros(1, 3, 96, 96))
  89. print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))
  90. test_y = test_net(test_x)
  91. print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))
  92. # -
  93. # 最后我们定义 DenseNet
  94. # + {"ExecuteTime": {"start_time": "2017-12-22T15:38:31.236857Z", "end_time": "2017-12-22T15:38:31.318822Z"}}
  95. class densenet(nn.Module):
  96. def __init__(self, in_channel, num_classes, growth_rate=32, block_layers=[6, 12, 24, 16]):
  97. super(densenet, self).__init__()
  98. self.block1 = nn.Sequential(
  99. nn.Conv2d(in_channel, 64, 7, 2, 3),
  100. nn.BatchNorm2d(64),
  101. nn.ReLU(True),
  102. nn.MaxPool2d(3, 2, padding=1)
  103. )
  104. channels = 64
  105. block = []
  106. for i, layers in enumerate(block_layers):
  107. block.append(dense_block(channels, growth_rate, layers))
  108. channels += layers * growth_rate
  109. if i != len(block_layers) - 1:
  110. block.append(transition(channels, channels // 2)) # 通过 transition 层将大小减半,通道数减半
  111. channels = channels // 2
  112. self.block2 = nn.Sequential(*block)
  113. self.block2.add_module('bn', nn.BatchNorm2d(channels))
  114. self.block2.add_module('relu', nn.ReLU(True))
  115. self.block2.add_module('avg_pool', nn.AvgPool2d(3))
  116. self.classifier = nn.Linear(channels, num_classes)
  117. def forward(self, x):
  118. x = self.block1(x)
  119. x = self.block2(x)
  120. x = x.view(x.shape[0], -1)
  121. x = self.classifier(x)
  122. return x
  123. # + {"ExecuteTime": {"start_time": "2017-12-22T15:38:31.320788Z", "end_time": "2017-12-22T15:38:31.654182Z"}}
  124. test_net = densenet(3, 10)
  125. test_x = Variable(torch.zeros(1, 3, 96, 96))
  126. test_y = test_net(test_x)
  127. print('output: {}'.format(test_y.shape))
  128. # + {"ExecuteTime": {"start_time": "2017-12-22T15:38:31.656356Z", "end_time": "2017-12-22T15:38:32.894729Z"}}
  129. from utils import train
  130. def data_tf(x):
  131. x = x.resize((96, 96), 2) # 将图片放大到 96 x 96
  132. x = np.array(x, dtype='float32') / 255
  133. x = (x - 0.5) / 0.5 # 标准化,这个技巧之后会讲到
  134. x = x.transpose((2, 0, 1)) # 将 channel 放到第一维,只是 pytorch 要求的输入方式
  135. x = torch.from_numpy(x)
  136. return x
  137. train_set = CIFAR10('./data', train=True, transform=data_tf)
  138. train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
  139. test_set = CIFAR10('./data', train=False, transform=data_tf)
  140. test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
  141. net = densenet(3, 10)
  142. optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
  143. criterion = nn.CrossEntropyLoss()
  144. # + {"ExecuteTime": {"start_time": "2017-12-22T15:38:32.896735Z", "end_time": "2017-12-22T16:15:38.168095Z"}}
  145. train(net, train_data, test_data, 20, optimizer, criterion)
  146. # -
  147. # DenseNet 将残差连接改为了特征拼接,使得网络有了更稠密的连接

机器学习越来越多地应用到飞行器、机器人等领域,其目的是利用计算机实现类似人类的智能,从而实现装备的智能化与无人化。本课程旨在引导学生掌握机器学习的基本知识、典型方法与技术,通过具体的应用案例激发学生对该学科的兴趣,鼓励学生能够从人工智能的角度来分析、解决飞行器、机器人所面临的问题和挑战。本课程主要内容包括Python编程基础,机器学习模型,无监督学习、监督学习、深度学习基础知识与实现,并学习如何利用机器学习解决实际问题,从而全面提升自我的综合能力。