
rnn-for-image.ipynb 6.9 kB

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Image Classification with RNNs\n",
"Earlier we saw that RNNs are particularly well suited to sequential data. Can an RNN also be used for image classification, the way a CNN is? Below we use MNIST handwritten digits to show how an RNN can classify images. This approach is not mainstream; we present it here only as an illustration."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A handwritten-digit image is 28 x 28 pixels, so we can treat it as a sequence of length 28 in which each step has 28 features:\n",
"\n",
"![](https://ws4.sinaimg.cn/large/006tKfTcly1fmu7d0byfkj30n60djdg5.jpg)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That settles the input sequence, but what about the output? It is actually very simple: although the RNN produces an output at every step, we only need to keep one of them as the result. Keeping the last one is clearly best, since it carries information from the entire preceding sequence, like this:\n",
"\n",
"![](https://ws3.sinaimg.cn/large/006tKfTcly1fmu7fpqri0j30c407yjr8.jpg)\n",
"\n",
"Let us demonstrate directly with an example."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-26T08:01:44.502896Z",
"start_time": "2017-12-26T08:01:44.062542Z"
},
"collapsed": true
},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('..')\n",
"\n",
"import torch\n",
"from torch.autograd import Variable\n",
"from torch import nn\n",
"from torch.utils.data import DataLoader\n",
"\n",
"from torchvision import transforms as tfs\n",
"from torchvision.datasets import MNIST"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-26T08:01:50.714439Z",
"start_time": "2017-12-26T08:01:50.650872Z"
},
"collapsed": true
},
"outputs": [],
"source": [
"# define the data\n",
"data_tf = tfs.Compose([\n",
" tfs.ToTensor(),\n",
" tfs.Normalize([0.5], [0.5]) # normalize\n",
"])\n",
"\n",
"train_set = MNIST('./data', train=True, transform=data_tf)\n",
"test_set = MNIST('./data', train=False, transform=data_tf)\n",
"\n",
"train_data = DataLoader(train_set, 64, True, num_workers=4)\n",
"test_data = DataLoader(test_set, 128, False, num_workers=4)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-26T08:01:51.165144Z",
"start_time": "2017-12-26T08:01:51.115807Z"
},
"collapsed": true
},
"outputs": [],
"source": [
"# define the model\n",
"class rnn_classify(nn.Module):\n",
" def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):\n",
" super(rnn_classify, self).__init__()\n",
" self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers) # a two-layer LSTM\n",
" self.classifier = nn.Linear(hidden_feature, num_class) # fully connected layer mapping the last RNN output to class scores\n",
" \n",
" def forward(self, x):\n",
" '''\n",
" x has shape (batch, 1, 28, 28); we need to convert it to the RNN input layout, i.e. (28, batch, 28)\n",
" '''\n",
" x = x.squeeze() # drop the 1 in (batch, 1, 28, 28), giving (batch, 28, 28)\n",
" x = x.permute(2, 0, 1) # move the last dimension to the front, giving (28, batch, 28)\n",
" out, _ = self.rnn(x) # use the default (zero) hidden state; out has shape (28, batch, hidden_feature)\n",
" out = out[-1, :, :] # take the last step of the sequence, shape (batch, hidden_feature)\n",
" out = self.classifier(out) # get the classification result\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-26T08:01:51.252533Z",
"start_time": "2017-12-26T08:01:51.244612Z"
},
"collapsed": true
},
"outputs": [],
"source": [
"net = rnn_classify()\n",
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"optimzier = torch.optim.Adadelta(net.parameters(), 1e-1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-26T08:03:36.739732Z",
"start_time": "2017-12-26T08:01:51.607967Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0. Train Loss: 1.858605, Train Acc: 0.318347, Valid Loss: 1.147508, Valid Acc: 0.578125, Time 00:00:09\n",
"Epoch 1. Train Loss: 0.503072, Train Acc: 0.848514, Valid Loss: 0.300552, Valid Acc: 0.912579, Time 00:00:09\n",
"Epoch 2. Train Loss: 0.224762, Train Acc: 0.934785, Valid Loss: 0.176321, Valid Acc: 0.946499, Time 00:00:09\n",
"Epoch 3. Train Loss: 0.157010, Train Acc: 0.953392, Valid Loss: 0.155280, Valid Acc: 0.954015, Time 00:00:09\n",
"Epoch 4. Train Loss: 0.125926, Train Acc: 0.962137, Valid Loss: 0.105295, Valid Acc: 0.969640, Time 00:00:09\n",
"Epoch 5. Train Loss: 0.104938, Train Acc: 0.968450, Valid Loss: 0.091477, Valid Acc: 0.972805, Time 00:00:10\n",
"Epoch 6. Train Loss: 0.089124, Train Acc: 0.973481, Valid Loss: 0.104799, Valid Acc: 0.969343, Time 00:00:09\n",
"Epoch 7. Train Loss: 0.077920, Train Acc: 0.976679, Valid Loss: 0.084242, Valid Acc: 0.976661, Time 00:00:10\n",
"Epoch 8. Train Loss: 0.070259, Train Acc: 0.978795, Valid Loss: 0.078536, Valid Acc: 0.977749, Time 00:00:09\n",
"Epoch 9. Train Loss: 0.063089, Train Acc: 0.981093, Valid Loss: 0.066984, Valid Acc: 0.980716, Time 00:00:09\n"
]
}
],
"source": [
"# start training\n",
"from utils import train\n",
"train(net, train_data, test_data, 10, optimzier, criterion)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see, 10 epochs of training reach about 98% accuracy even on the simple MNIST dataset, so an RNN can indeed handle simple image classification. This is not where it shines, however; in the next lesson we will look at a real use case for RNNs: time-series prediction."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
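The shape handling in the notebook's `forward` method is the crux of treating an image as a sequence. A minimal standalone sketch of the same transformations, using random data in place of MNIST:

```python
import torch
from torch import nn

x = torch.randn(4, 1, 28, 28)   # a fake batch of 4 MNIST-sized images: (batch, 1, 28, 28)

x = x.squeeze(1)                # (4, 28, 28): drop the channel dimension
x = x.permute(2, 0, 1)          # (28, 4, 28): (seq_len, batch, features), nn.LSTM's default layout

rnn = nn.LSTM(input_size=28, hidden_size=100, num_layers=2)
out, _ = rnn(x)                 # (28, 4, 100): one output per image row
last = out[-1]                  # (4, 100): the final step summarizes the whole sequence
logits = nn.Linear(100, 10)(last)  # (4, 10): class scores
print(logits.shape)             # torch.Size([4, 10])
```

Note that `squeeze(1)` is used here instead of the notebook's bare `squeeze()`, which would also remove the batch dimension whenever a batch happens to contain a single image.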
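The `train` helper is imported from the course's own `utils` module, which the notebook does not show. A rough sketch of what a loop with that signature might do (an assumption, not the course's actual implementation) is:

```python
import torch

def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    # hypothetical sketch of the unshown utils.train helper
    for epoch in range(num_epochs):
        net.train()
        train_loss, train_acc, n = 0.0, 0, 0
        for im, label in train_data:
            out = net(im)
            loss = criterion(out, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * label.size(0)
            train_acc += (out.argmax(1) == label).sum().item()
            n += label.size(0)
        net.eval()
        valid_loss, valid_acc, m = 0.0, 0, 0
        with torch.no_grad():
            for im, label in valid_data:
                out = net(im)
                valid_loss += criterion(out, label).item() * label.size(0)
                valid_acc += (out.argmax(1) == label).sum().item()
                m += label.size(0)
        print('Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f'
              % (epoch, train_loss / n, train_acc / n, valid_loss / m, valid_acc / m))
```

The real helper also reports wall-clock time per epoch, as the printed log above shows; that bookkeeping is omitted here for brevity.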
