{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# RNN 做图像分类\n", "前面我们讲了 RNN 特别适合做序列类型的数据,那么 RNN 能不能想 CNN 一样用来做图像分类呢?下面我们用 mnist 手写字体的例子来展示一下如何用 RNN 做图像分类,但是这种方法并不是主流,这里我们只是作为举例。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "对于一张手写字体的图片,其大小是 28 * 28,我们可以将其看做是一个长为 28 的序列,每个序列的特征都是 28,也就是\n", "\n", "![](https://ws4.sinaimg.cn/large/006tKfTcly1fmu7d0byfkj30n60djdg5.jpg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "这样我们解决了输入序列的问题,对于输出序列怎么办呢?其实非常简单,虽然我们的输出是一个序列,但是我们只需要保留其中一个作为输出结果就可以了,这样的话肯定保留最后一个结果是最好的,因为最后一个结果有前面所有序列的信息,就像下面这样\n", "\n", "![](https://ws3.sinaimg.cn/large/006tKfTcly1fmu7fpqri0j30c407yjr8.jpg)\n", "\n", "下面我们直接通过例子展示" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2017-12-26T08:01:44.502896Z", "start_time": "2017-12-26T08:01:44.062542Z" }, "collapsed": true }, "outputs": [], "source": [ "import sys\n", "sys.path.append('..')\n", "\n", "import torch\n", "from torch.autograd import Variable\n", "from torch import nn\n", "from torch.utils.data import DataLoader\n", "\n", "from torchvision import transforms as tfs\n", "from torchvision.datasets import MNIST" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2017-12-26T08:01:50.714439Z", "start_time": "2017-12-26T08:01:50.650872Z" }, "collapsed": true }, "outputs": [], "source": [ "# 定义数据\n", "data_tf = tfs.Compose([\n", " tfs.ToTensor(),\n", " tfs.Normalize([0.5], [0.5]) # 标准化\n", "])\n", "\n", "train_set = MNIST('./data', train=True, transform=data_tf)\n", "test_set = MNIST('./data', train=False, transform=data_tf)\n", "\n", "train_data = DataLoader(train_set, 64, True, num_workers=4)\n", "test_data = DataLoader(test_set, 128, False, num_workers=4)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2017-12-26T08:01:51.165144Z", "start_time": "2017-12-26T08:01:51.115807Z" }, "collapsed": true }, "outputs": [], "source": [ "# 定义模型\n", "class rnn_classify(nn.Module):\n", " def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):\n", " super(rnn_classify, self).__init__()\n", " self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers) # 使用两层 lstm\n", " self.classifier = nn.Linear(hidden_feature, num_class) # 将最后一个 rnn 的输出使用全连接得到最后的分类结果\n", " \n", " def forward(self, x):\n", " '''\n", " x 大小为 (batch, 1, 28, 28),所以我们需要将其转换成 RNN 的输入形式,即 (28, batch, 28)\n", " '''\n", " x = x.squeeze() # 去掉 (batch, 1, 28, 28) 中的 1,变成 (batch, 28, 28)\n", " x = x.permute(2, 0, 1) # 将最后一维放到第一维,变成 (28, batch, 28)\n", " out, _ = self.rnn(x) # 使用默认的隐藏状态,得到的 out 是 (28, batch, hidden_feature)\n", " out = out[-1, :, :] # 取序列中的最后一个,大小是 (batch, hidden_feature)\n", " out = self.classifier(out) # 得到分类结果\n", " return out" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2017-12-26T08:01:51.252533Z", "start_time": "2017-12-26T08:01:51.244612Z" }, "collapsed": true }, "outputs": [], "source": [ "net = rnn_classify()\n", "criterion = nn.CrossEntropyLoss()\n", "\n", "optimzier = torch.optim.Adadelta(net.parameters(), 1e-1)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2017-12-26T08:03:36.739732Z", "start_time": "2017-12-26T08:01:51.607967Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0. Train Loss: 1.858605, Train Acc: 0.318347, Valid Loss: 1.147508, Valid Acc: 0.578125, Time 00:00:09\n", "Epoch 1. Train Loss: 0.503072, Train Acc: 0.848514, Valid Loss: 0.300552, Valid Acc: 0.912579, Time 00:00:09\n", "Epoch 2. Train Loss: 0.224762, Train Acc: 0.934785, Valid Loss: 0.176321, Valid Acc: 0.946499, Time 00:00:09\n", "Epoch 3. Train Loss: 0.157010, Train Acc: 0.953392, Valid Loss: 0.155280, Valid Acc: 0.954015, Time 00:00:09\n", "Epoch 4. Train Loss: 0.125926, Train Acc: 0.962137, Valid Loss: 0.105295, Valid Acc: 0.969640, Time 00:00:09\n", "Epoch 5. Train Loss: 0.104938, Train Acc: 0.968450, Valid Loss: 0.091477, Valid Acc: 0.972805, Time 00:00:10\n", "Epoch 6. Train Loss: 0.089124, Train Acc: 0.973481, Valid Loss: 0.104799, Valid Acc: 0.969343, Time 00:00:09\n", "Epoch 7. Train Loss: 0.077920, Train Acc: 0.976679, Valid Loss: 0.084242, Valid Acc: 0.976661, Time 00:00:10\n", "Epoch 8. Train Loss: 0.070259, Train Acc: 0.978795, Valid Loss: 0.078536, Valid Acc: 0.977749, Time 00:00:09\n", "Epoch 9. Train Loss: 0.063089, Train Acc: 0.981093, Valid Loss: 0.066984, Valid Acc: 0.980716, Time 00:00:09\n" ] } ], "source": [ "# 开始训练\n", "from utils import train\n", "train(net, train_data, test_data, 10, optimzier, criterion)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以看到,训练 10 次在简单的 mnist 数据集上也取得的了 98% 的准确率,所以说 RNN 也可以做做简单的图像分类,但是这并不是他的主战场,下次课我们会讲到 RNN 的一个使用场景,时间序列预测。" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }