{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PyTorch 中的循环神经网络模块\n", "前面我们讲了循环神经网络的基础知识和网络结构,下面我们教大家如何在 pytorch 下构建循环神经网络,因为 pytorch 的动态图机制,使得循环神经网络非常方便。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 一般的 RNN\n", "\n", "![](https://ws1.sinaimg.cn/large/006tKfTcly1fmt9xz889xj30kb07nglo.jpg)\n", "\n", "对于最简单的 RNN,我们可以使用下面两种方式去调用,分别是 `torch.nn.RNNCell()` 和 `torch.nn.RNN()`,这两种方式的区别在于 `RNNCell()` 只能接受序列中单步的输入,且必须传入隐藏状态,而 `RNN()` 可以接受一个序列的输入,默认会传入全 0 的隐藏状态,也可以自己申明隐藏状态传入。\n", "\n", "`RNN()` 里面的参数有\n", "\n", "input_size 表示输入 $x_t$ 的特征维度\n", "\n", "hidden_size 表示输出的特征维度\n", "\n", "num_layers 表示网络的层数\n", "\n", "nonlinearity 表示选用的非线性激活函数,默认是 'tanh'\n", "\n", "bias 表示是否使用偏置,默认使用\n", "\n", "batch_first 表示输入数据的形式,默认是 False,就是这样形式,(seq, batch, feature),也就是将序列长度放在第一位,batch 放在第二位\n", "\n", "dropout 表示是否在输出层应用 dropout\n", "\n", "bidirectional 表示是否使用双向的 rnn,默认是 False\n", "\n", "对于 `RNNCell()`,里面的参数就少很多,只有 input_size,hidden_size,bias 以及 nonlinearity" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.autograd import Variable\n", "from torch import nn" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# 定义一个单步的 rnn\n", "rnn_single = nn.RNNCell(input_size=100, hidden_size=200)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Parameter containing:\n", "tensor([[-2.7963e-02, 3.6102e-02, 5.6609e-03, ..., -3.0035e-02,\n", " 2.7740e-02, 2.3327e-02],\n", " [-2.8567e-02, -3.2150e-02, -2.6686e-02, ..., -4.6441e-02,\n", " 3.5804e-02, 9.7260e-05],\n", " [ 4.6686e-02, -1.5825e-02, 6.7149e-02, ..., 3.3435e-02,\n", " -2.7623e-02, -6.7693e-02],\n", " ...,\n", " [-2.0338e-02, -1.6551e-02, 5.8996e-02, ..., -4.0145e-02,\n", " -6.9111e-03, -3.2740e-02],\n", " [-2.4584e-02, 2.3591e-02, 8.3090e-03, ..., -3.6077e-02,\n", " -6.0432e-03, 5.6279e-02],\n", " [ 5.6955e-02, -5.1925e-02, 3.1950e-02, ..., -5.6692e-02,\n", " 6.1773e-02, 1.9715e-02]], requires_grad=True)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 访问其中的参数\n", "rnn_single.weight_hh" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[[ 1.4637, -2.0015, 0.6298, ..., -1.1210, -1.6310, 0.5122],\n", " [-0.1500, -0.6931, 0.1568, ..., -0.9185, -0.5088, -1.0746],\n", " [ 0.1717, 1.2186, -0.8093, ..., 0.8630, 0.4601, -1.0218],\n", " [-0.3034, 2.8634, 2.2470, ..., 0.1678, -2.0585, -0.9628],\n", " [-2.3764, -0.4235, -1.1760, ..., -1.2251, 0.6761, -1.0323]],\n", "\n", " [[-1.3497, -0.6778, -0.0528, ..., -0.1852, -0.3997, -0.7633],\n", " [ 1.0105, 0.7974, 0.4253, ..., -1.1167, -1.3870, -1.3583],\n", " [ 0.2785, 0.5013, -0.5881, ..., -0.0283, 0.6044, -0.3249],\n", " [-1.9298, -0.6575, -1.2878, ..., 0.5636, -0.3266, 1.9391],\n", " [ 1.3117, -1.1429, -1.5837, ..., -1.5248, -0.2046, 1.0696]],\n", "\n", " [[-0.8637, -1.0572, -0.2438, ..., 0.1011, -0.4630, 0.0526],\n", " [-0.0056, -0.9442, -0.5588, ..., -0.6881, -1.2189, -1.1846],\n", " [ 0.8341, 0.6924, -0.4376, ..., 1.1331, -0.9766, 1.3822],\n", " [-0.3815, -1.3457, 0.5320, ..., 0.8280, 0.2146, -0.8704],\n", " [-0.6424, 1.3608, -0.5325, ..., -0.3414, 1.0094, 1.2650]],\n", "\n", " [[-0.1776, -0.2037, -0.7093, ..., -1.1442, -1.0058, -0.6898],\n", " [ 0.2921, -1.9473, -0.6989, ..., 0.6852, -0.2225, -0.6484],\n", " [-0.8576, 1.9338, -1.5359, ..., -0.3545, -0.9438, 0.1476],\n", " [ 2.3669, 0.8673, 2.0521, ..., -0.4679, -0.4050, 0.7761],\n", " [ 0.3706, 1.2876, -0.5311, ..., 0.4794, -0.4209, 0.5343]],\n", "\n", " [[-0.2726, -1.2583, -0.8259, ..., 0.8811, 0.5900, 0.1770],\n", " [ 1.1066, -0.4899, 0.9143, ..., -2.2898, 0.1525, -2.2099],\n", " [-1.3824, 0.3142, 1.2140, ..., 0.5470, -0.4883, -0.3204],\n", " [ 1.8471, 0.6011, 0.0613, ..., 1.1584, -0.8014, 0.4891],\n", " [ 1.5201, -1.7853, 1.3107, ..., 0.0032, -1.3422, 0.7332]],\n", "\n", " [[ 0.3025, -0.7314, -0.2032, ..., -0.9658, -1.8131, 0.5922],\n", " [-0.0878, 0.0909, 0.7064, ..., 2.4186, -0.0863, 0.0930],\n", " [-1.4278, -1.0901, 1.6742, ..., 0.3020, -0.6106, -0.4299],\n", " [-1.8291, -1.1337, -0.2405, ..., -1.2000, 2.0510, 1.3617],\n", " [-2.7953, -0.0559, 1.0224, ..., 0.4400, 0.9099, -1.5845]]])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 构造一个序列,长为 6,batch 是 5, 特征是 100\n", "x = Variable(torch.randn(6, 5, 100)) # 这是 rnn 的输入格式\n", "x" ] }, { "cell_type": "code", "execution_count": 50, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# 定义初始的记忆状态\n", "h_t = Variable(torch.zeros(5, 200))" ] }, { "cell_type": "code", "execution_count": 51, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# 传入 rnn\n", "out = []\n", "for i in range(6): # 通过循环 6 次作用在整个序列上\n", " h_t = rnn_single(x[i], h_t)\n", " out.append(h_t)" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Variable containing:\n", " 0.0136 0.3723 0.1704 ... 0.4306 -0.7909 -0.5306\n", "-0.2681 -0.6261 -0.3926 ... 0.1752 0.5739 -0.2061\n", "-0.4918 -0.7611 0.2787 ... 0.0854 -0.3899 0.0092\n", " 0.6050 0.1852 -0.4261 ... -0.7220 0.6809 0.1825\n", "-0.6851 0.7273 0.5396 ... -0.7969 0.6133 -0.0852\n", "[torch.FloatTensor of size 5x200]" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "h_t" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "6" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(out)" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([5, 200])" ] }, "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out[0].shape # 每个输出的维度" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以看到经过了 rnn 之后,隐藏状态的值已经被改变了,因为网络记忆了序列中的信息,同时输出 6 个结果" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "下面我们看看直接使用 `RNN` 的情况" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "collapsed": true }, "outputs": [], "source": [ "rnn_seq = nn.RNN(100, 200)" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Parameter containing:\n", "1.00000e-02 *\n", " 1.0998 -1.5018 -1.4337 ... 3.8385 -0.8958 -1.6781\n", " 5.3302 -5.4654 5.5568 ... 4.7399 5.4110 3.6170\n", " 1.0788 -0.6620 5.7689 ... -5.0747 -2.9066 0.6152\n", " ... ⋱ ... \n", "-5.6921 0.1843 -0.0803 ... -4.5852 5.6194 -1.4734\n", " 4.4306 6.9795 -1.5736 ... 3.4236 -0.3441 3.1397\n", " 7.0349 -1.6120 -4.2840 ... -5.5676 6.8897 6.1968\n", "[torch.FloatTensor of size 200x200]" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 访问其中的参数\n", "rnn_seq.weight_hh_l0" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "collapsed": true }, "outputs": [], "source": [ "out, h_t = rnn_seq(x) # 使用默认的全 0 隐藏状态" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Variable containing:\n", "( 0 ,.,.) = \n", " 0.2012 0.0517 0.0570 ... 0.2316 0.3615 -0.1247\n", " 0.5307 0.4147 0.7881 ... -0.4138 -0.1444 0.3602\n", " 0.0882 0.4307 0.3939 ... 0.3244 -0.4629 -0.2315\n", " 0.2868 0.7400 0.6534 ... 0.6631 0.2624 -0.0162\n", " 0.0841 0.6274 0.1840 ... 0.5800 0.8780 0.4301\n", "[torch.FloatTensor of size 1x5x200]" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "h_t" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "6" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "这里的 h_t 是网络最后的隐藏状态,网络也输出了 6 个结果" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# 自己定义初始的隐藏状态\n", "h_0 = Variable(torch.randn(1, 5, 200))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "这里的隐藏状态的大小有三个维度,分别是 (num_layers * num_direction, batch, hidden_size)" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "out, h_t = rnn_seq(x, h_0)" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Variable containing:\n", "( 0 ,.,.) = \n", " 0.2091 0.0353 0.0625 ... 0.2340 0.3734 -0.1307\n", " 0.5498 0.4221 0.7877 ... -0.4143 -0.1209 0.3335\n", " 0.0757 0.4204 0.3826 ... 0.3187 -0.4626 -0.2336\n", " 0.3106 0.7355 0.6436 ... 0.6611 0.2587 -0.0338\n", " 0.1025 0.6350 0.1943 ... 0.5720 0.8749 0.4525\n", "[torch.FloatTensor of size 1x5x200]" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "h_t" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([6, 5, 200])" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "同时输出的结果也是 (seq, batch, feature)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "一般情况下我们都是用 `nn.RNN()` 而不是 `nn.RNNCell()`,因为 `nn.RNN()` 能够避免我们手动写循环,非常方便,同时如果不特别说明,我们也会选择使用默认的全 0 初始化隐藏状态" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## LSTM" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](https://ws1.sinaimg.cn/large/006tKfTcly1fmt9qj3uhmj30iz07ct90.jpg)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "LSTM 和基本的 RNN 是一样的,他的参数也是相同的,同时他也有 `nn.LSTMCell()` 和 `nn.LSTM()` 两种形式,跟前面讲的都是相同的,我们就不再赘述了,下面直接举个小例子" ] }, { "cell_type": "code", "execution_count": 58, "metadata": { "collapsed": true }, "outputs": [], "source": [ "lstm_seq = nn.LSTM(50, 100, num_layers=2) # 输入维度 100,输出 200,两层" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Parameter containing:\n", "1.00000e-02 *\n", " 3.8420 5.7387 6.1351 ... 1.2680 0.9890 1.3037\n", "-4.2301 6.8294 -4.8627 ... -6.4147 4.3015 8.4103\n", " 9.4411 5.0195 9.8620 ... -1.6096 9.2516 -0.6941\n", " ... ⋱ ... \n", " 1.2930 -1.3300 -0.9311 ... -6.0891 -0.7164 3.9578\n", " 9.0435 2.4674 9.4107 ... -3.3822 -3.9773 -3.0685\n", "-4.2039 -8.2992 -3.3605 ... 2.2875 8.2163 -9.3277\n", "[torch.FloatTensor of size 400x100]" ] }, "execution_count": 80, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lstm_seq.weight_hh_l0 # 第一层的 h_t 权重" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**小练习:想想为什么这个系数的大小是 (400, 100)**" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "lstm_input = Variable(torch.randn(10, 3, 50)) # 序列 10,batch 是 3,输入维度 50" ] }, { "cell_type": "code", "execution_count": 64, "metadata": { "collapsed": true }, "outputs": [], "source": [ "out, (h, c) = lstm_seq(lstm_input) # 使用默认的全 0 隐藏状态" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "注意这里 LSTM 输出的隐藏状态有两个,h 和 c,就是上图中的每个 cell 之间的两个箭头,这两个隐藏状态的大小都是相同的,(num_layers * direction, batch, feature)" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 3, 100])" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "h.shape # 两层,Batch 是 3,特征是 100" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 3, 100])" ] }, "execution_count": 67, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c.shape" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([10, 3, 100])" ] }, "execution_count": 61, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "我们可以不使用默认的隐藏状态,这是需要传入两个张量" ] }, { "cell_type": "code", "execution_count": 68, "metadata": { "collapsed": true }, "outputs": [], "source": [ "h_init = Variable(torch.randn(2, 3, 100))\n", "c_init = Variable(torch.randn(2, 3, 100))" ] }, { "cell_type": "code", "execution_count": 69, "metadata": { "collapsed": true }, "outputs": [], "source": [ "out, (h, c) = lstm_seq(lstm_input, (h_init, c_init))" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 3, 100])" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "h.shape" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 3, 100])" ] }, "execution_count": 71, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c.shape" ] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([10, 3, 100])" ] }, "execution_count": 72, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# GRU\n", "![](https://ws3.sinaimg.cn/large/006tKfTcly1fmtaj38y9sj30io06bmxc.jpg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "GRU 和前面讲的这两个是同样的道理,就不再细说,还是演示一下例子" ] }, { "cell_type": "code", "execution_count": 73, "metadata": { "collapsed": true }, "outputs": [], "source": [ "gru_seq = nn.GRU(10, 20)\n", "gru_input = Variable(torch.randn(3, 32, 10))\n", "\n", "out, h = gru_seq(gru_input)" ] }, { "cell_type": "code", "execution_count": 76, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Parameter containing:\n", " 0.0766 -0.0548 -0.2008 ... -0.0250 -0.1819 0.1453\n", "-0.1676 0.1622 0.0417 ... 0.1905 -0.0071 -0.1038\n", " 0.0444 -0.1516 0.2194 ... -0.0009 0.0771 0.0476\n", " ... ⋱ ... \n", " 0.1698 -0.1707 0.0340 ... -0.1315 0.1278 0.0946\n", " 0.1936 0.1369 -0.0694 ... -0.0667 0.0429 0.1322\n", " 0.0870 -0.1884 0.1732 ... -0.1423 -0.1723 0.2147\n", "[torch.FloatTensor of size 60x20]" ] }, "execution_count": 76, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gru_seq.weight_hh_l0" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 32, 20])" ] }, "execution_count": 75, "metadata": {}, "output_type": "execute_result" } ], "source": [ "h.shape" ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([3, 32, 20])" ] }, "execution_count": 74, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out.shape" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }