{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Recurrent Neural Network Modules in PyTorch\n",
"Previously we covered the basics and structure of recurrent neural networks. Here we show how to build recurrent networks in PyTorch; thanks to PyTorch's dynamic graph mechanism, working with them is very convenient."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Vanilla RNN\n",
"\n",
"![](https://ws1.sinaimg.cn/large/006tKfTcly1fmt9xz889xj30kb07nglo.jpg)\n",
"\n",
"For the simplest RNN, PyTorch offers two interfaces: `torch.nn.RNNCell()` and `torch.nn.RNN()`. The difference is that `RNNCell()` only accepts a single step of the sequence and requires the hidden state to be passed in explicitly, whereas `RNN()` accepts a whole sequence and defaults to an all-zero initial hidden state, though you may also pass in your own.\n",
"\n",
"The parameters of `RNN()` are:\n",
"\n",
"- `input_size`: the feature dimension of the input $x_t$\n",
"- `hidden_size`: the feature dimension of the hidden state (and hence of the output)\n",
"- `num_layers`: the number of stacked layers\n",
"- `nonlinearity`: the activation function to use, 'tanh' by default\n",
"- `bias`: whether to use bias terms, enabled by default\n",
"- `batch_first`: the layout of the input; defaults to False, i.e. (seq, batch, feature), with the sequence length first and the batch second\n",
"- `dropout`: if non-zero, applies dropout to the outputs of every layer except the last\n",
"- `bidirectional`: whether to use a bidirectional RNN, False by default\n",
"\n",
"`RNNCell()` takes far fewer parameters: only `input_size`, `hidden_size`, `bias`, and `nonlinearity`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.autograd import Variable # legacy wrapper; a no-op in modern PyTorch\n",
"from torch import nn"
]
},
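{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of the parameter list above, an RNN can be constructed with every argument spelled out by keyword (the sizes are arbitrary, the remaining values are simply the defaults, and `rnn_demo` is an illustrative name):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# illustrative only: all keyword arguments of nn.RNN made explicit\n",
"rnn_demo = nn.RNN(input_size=100, hidden_size=200, num_layers=1,\n",
"                  nonlinearity='tanh', bias=True, batch_first=False,\n",
"                  dropout=0, bidirectional=False)"
]
},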
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# define a single-step RNN cell\n",
"rnn_single = nn.RNNCell(input_size=100, hidden_size=200)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Parameter containing:\n",
"tensor([[-2.7963e-02, 3.6102e-02, 5.6609e-03, ..., -3.0035e-02,\n",
" 2.7740e-02, 2.3327e-02],\n",
" [-2.8567e-02, -3.2150e-02, -2.6686e-02, ..., -4.6441e-02,\n",
" 3.5804e-02, 9.7260e-05],\n",
" [ 4.6686e-02, -1.5825e-02, 6.7149e-02, ..., 3.3435e-02,\n",
" -2.7623e-02, -6.7693e-02],\n",
" ...,\n",
" [-2.0338e-02, -1.6551e-02, 5.8996e-02, ..., -4.0145e-02,\n",
" -6.9111e-03, -3.2740e-02],\n",
" [-2.4584e-02, 2.3591e-02, 8.3090e-03, ..., -3.6077e-02,\n",
" -6.0432e-03, 5.6279e-02],\n",
" [ 5.6955e-02, -5.1925e-02, 3.1950e-02, ..., -5.6692e-02,\n",
" 6.1773e-02, 1.9715e-02]], requires_grad=True)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# inspect its parameters\n",
"rnn_single.weight_hh"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 1.4637, -2.0015, 0.6298, ..., -1.1210, -1.6310, 0.5122],\n",
" [-0.1500, -0.6931, 0.1568, ..., -0.9185, -0.5088, -1.0746],\n",
" [ 0.1717, 1.2186, -0.8093, ..., 0.8630, 0.4601, -1.0218],\n",
" [-0.3034, 2.8634, 2.2470, ..., 0.1678, -2.0585, -0.9628],\n",
" [-2.3764, -0.4235, -1.1760, ..., -1.2251, 0.6761, -1.0323]],\n",
"\n",
" [[-1.3497, -0.6778, -0.0528, ..., -0.1852, -0.3997, -0.7633],\n",
" [ 1.0105, 0.7974, 0.4253, ..., -1.1167, -1.3870, -1.3583],\n",
" [ 0.2785, 0.5013, -0.5881, ..., -0.0283, 0.6044, -0.3249],\n",
" [-1.9298, -0.6575, -1.2878, ..., 0.5636, -0.3266, 1.9391],\n",
" [ 1.3117, -1.1429, -1.5837, ..., -1.5248, -0.2046, 1.0696]],\n",
"\n",
" [[-0.8637, -1.0572, -0.2438, ..., 0.1011, -0.4630, 0.0526],\n",
" [-0.0056, -0.9442, -0.5588, ..., -0.6881, -1.2189, -1.1846],\n",
" [ 0.8341, 0.6924, -0.4376, ..., 1.1331, -0.9766, 1.3822],\n",
" [-0.3815, -1.3457, 0.5320, ..., 0.8280, 0.2146, -0.8704],\n",
" [-0.6424, 1.3608, -0.5325, ..., -0.3414, 1.0094, 1.2650]],\n",
"\n",
" [[-0.1776, -0.2037, -0.7093, ..., -1.1442, -1.0058, -0.6898],\n",
" [ 0.2921, -1.9473, -0.6989, ..., 0.6852, -0.2225, -0.6484],\n",
" [-0.8576, 1.9338, -1.5359, ..., -0.3545, -0.9438, 0.1476],\n",
" [ 2.3669, 0.8673, 2.0521, ..., -0.4679, -0.4050, 0.7761],\n",
" [ 0.3706, 1.2876, -0.5311, ..., 0.4794, -0.4209, 0.5343]],\n",
"\n",
" [[-0.2726, -1.2583, -0.8259, ..., 0.8811, 0.5900, 0.1770],\n",
" [ 1.1066, -0.4899, 0.9143, ..., -2.2898, 0.1525, -2.2099],\n",
" [-1.3824, 0.3142, 1.2140, ..., 0.5470, -0.4883, -0.3204],\n",
" [ 1.8471, 0.6011, 0.0613, ..., 1.1584, -0.8014, 0.4891],\n",
" [ 1.5201, -1.7853, 1.3107, ..., 0.0032, -1.3422, 0.7332]],\n",
"\n",
" [[ 0.3025, -0.7314, -0.2032, ..., -0.9658, -1.8131, 0.5922],\n",
" [-0.0878, 0.0909, 0.7064, ..., 2.4186, -0.0863, 0.0930],\n",
" [-1.4278, -1.0901, 1.6742, ..., 0.3020, -0.6106, -0.4299],\n",
" [-1.8291, -1.1337, -0.2405, ..., -1.2000, 2.0510, 1.3617],\n",
" [-2.7953, -0.0559, 1.0224, ..., 0.4400, 0.9099, -1.5845]]])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# build a sequence: length 6, batch 5, feature size 100\n",
"x = Variable(torch.randn(6, 5, 100)) # this is the input layout the RNN expects\n",
"x"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# define the initial hidden state\n",
"h_t = Variable(torch.zeros(5, 200))"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# feed the sequence through the RNN cell\n",
"out = []\n",
"for i in range(6): # loop 6 times to cover the whole sequence\n",
"    h_t = rnn_single(x[i], h_t)\n",
"    out.append(h_t)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Variable containing:\n",
" 0.0136 0.3723 0.1704 ... 0.4306 -0.7909 -0.5306\n",
"-0.2681 -0.6261 -0.3926 ... 0.1752 0.5739 -0.2061\n",
"-0.4918 -0.7611 0.2787 ... 0.0854 -0.3899 0.0092\n",
" 0.6050 0.1852 -0.4261 ... -0.7220 0.6809 0.1825\n",
"-0.6851 0.7273 0.5396 ... -0.7969 0.6133 -0.0852\n",
"[torch.FloatTensor of size 5x200]"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"h_t"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(out)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([5, 200])"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out[0].shape # shape of each output"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that after running the RNN the hidden state has changed, because the network has accumulated information from the sequence; it also produced 6 outputs along the way."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, let's see how to use `RNN` directly."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"rnn_seq = nn.RNN(100, 200)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Parameter containing:\n",
"1.00000e-02 *\n",
" 1.0998 -1.5018 -1.4337 ... 3.8385 -0.8958 -1.6781\n",
" 5.3302 -5.4654 5.5568 ... 4.7399 5.4110 3.6170\n",
" 1.0788 -0.6620 5.7689 ... -5.0747 -2.9066 0.6152\n",
" ... ⋱ ... \n",
"-5.6921 0.1843 -0.0803 ... -4.5852 5.6194 -1.4734\n",
" 4.4306 6.9795 -1.5736 ... 3.4236 -0.3441 3.1397\n",
" 7.0349 -1.6120 -4.2840 ... -5.5676 6.8897 6.1968\n",
"[torch.FloatTensor of size 200x200]"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# inspect its parameters\n",
"rnn_seq.weight_hh_l0"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"out, h_t = rnn_seq(x) # use the default all-zero hidden state"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Variable containing:\n",
"( 0 ,.,.) = \n",
" 0.2012 0.0517 0.0570 ... 0.2316 0.3615 -0.1247\n",
" 0.5307 0.4147 0.7881 ... -0.4138 -0.1444 0.3602\n",
" 0.0882 0.4307 0.3939 ... 0.3244 -0.4629 -0.2315\n",
" 0.2868 0.7400 0.6534 ... 0.6631 0.2624 -0.0162\n",
" 0.0841 0.6274 0.1840 ... 0.5800 0.8780 0.4301\n",
"[torch.FloatTensor of size 1x5x200]"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"h_t"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(out)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here `h_t` is the final hidden state of the network, and again the network produced 6 outputs."
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# define our own initial hidden state\n",
"h_0 = Variable(torch.randn(1, 5, 200))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The hidden state here has three dimensions: (num_layers * num_directions, batch, hidden_size)."
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"out, h_t = rnn_seq(x, h_0)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Variable containing:\n",
"( 0 ,.,.) = \n",
" 0.2091 0.0353 0.0625 ... 0.2340 0.3734 -0.1307\n",
" 0.5498 0.4221 0.7877 ... -0.4143 -0.1209 0.3335\n",
" 0.0757 0.4204 0.3826 ... 0.3187 -0.4626 -0.2336\n",
" 0.3106 0.7355 0.6436 ... 0.6611 0.2587 -0.0338\n",
" 0.1025 0.6350 0.1943 ... 0.5720 0.8749 0.4525\n",
"[torch.FloatTensor of size 1x5x200]"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"h_t"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([6, 5, 200])"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The output is likewise laid out as (seq, batch, feature)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In practice we almost always use `nn.RNN()` rather than `nn.RNNCell()`, since `nn.RNN()` saves us from writing the time-step loop by hand. Unless stated otherwise, we will also stick with the default all-zero initial hidden state."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## LSTM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![](https://ws1.sinaimg.cn/large/006tKfTcly1fmt9qj3uhmj30iz07ct90.jpg)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"An LSTM works just like the basic RNN and takes the same constructor parameters; it likewise comes in the two forms `nn.LSTMCell()` and `nn.LSTM()`, exactly as described above, so we won't repeat the details and instead go straight to a small example."
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"lstm_seq = nn.LSTM(50, 100, num_layers=2) # input size 50, hidden size 100, two layers"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Parameter containing:\n",
"1.00000e-02 *\n",
" 3.8420 5.7387 6.1351 ... 1.2680 0.9890 1.3037\n",
"-4.2301 6.8294 -4.8627 ... -6.4147 4.3015 8.4103\n",
" 9.4411 5.0195 9.8620 ... -1.6096 9.2516 -0.6941\n",
" ... ⋱ ... \n",
" 1.2930 -1.3300 -0.9311 ... -6.0891 -0.7164 3.9578\n",
" 9.0435 2.4674 9.4107 ... -3.3822 -3.9773 -3.0685\n",
"-4.2039 -8.2992 -3.3605 ... 2.2875 8.2163 -9.3277\n",
"[torch.FloatTensor of size 400x100]"
]
},
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lstm_seq.weight_hh_l0 # hidden-to-hidden weights of the first layer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise: think about why this weight matrix has shape (400, 100).**"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"lstm_input = Variable(torch.randn(10, 3, 50)) # sequence length 10, batch 3, input size 50"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"out, (h, c) = lstm_seq(lstm_input) # use the default all-zero hidden states"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that an LSTM returns two hidden states, h and c, which correspond to the two arrows between the cells in the figure above. Both have the same shape: (num_layers * num_directions, batch, hidden_size)."
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 3, 100])"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"h.shape # 2 layers, batch 3, feature size 100"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 3, 100])"
]
},
"execution_count": 67,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c.shape"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([10, 3, 100])"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also choose not to use the default hidden states; in that case we need to pass in two tensors."
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"h_init = Variable(torch.randn(2, 3, 100))\n",
"c_init = Variable(torch.randn(2, 3, 100))"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"out, (h, c) = lstm_seq(lstm_input, (h_init, c_init))"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 3, 100])"
]
},
"execution_count": 70,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"h.shape"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 3, 100])"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c.shape"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([10, 3, 100])"
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## GRU\n",
"![](https://ws3.sinaimg.cn/large/006tKfTcly1fmtaj38y9sj30io06bmxc.jpg)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A GRU works on the same principle as the two modules above, so we won't go into detail; here is a quick demonstration."
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"gru_seq = nn.GRU(10, 20)\n",
"gru_input = Variable(torch.randn(3, 32, 10))\n",
"\n",
"out, h = gru_seq(gru_input)"
]
},
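{
"cell_type": "markdown",
"metadata": {},
"source": [
"As with `nn.RNN` and `nn.LSTM`, a GRU also accepts a custom initial hidden state of shape (num_layers * num_directions, batch, hidden_size). A minimal sketch (`h_init_gru` is an illustrative name):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# illustrative only: pass our own initial hidden state instead of the default zeros\n",
"h_init_gru = Variable(torch.randn(1, 32, 20)) # (num_layers * num_directions, batch, hidden_size)\n",
"out, h = gru_seq(gru_input, h_init_gru)"
]
},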
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Parameter containing:\n",
" 0.0766 -0.0548 -0.2008 ... -0.0250 -0.1819 0.1453\n",
"-0.1676 0.1622 0.0417 ... 0.1905 -0.0071 -0.1038\n",
" 0.0444 -0.1516 0.2194 ... -0.0009 0.0771 0.0476\n",
" ... ⋱ ... \n",
" 0.1698 -0.1707 0.0340 ... -0.1315 0.1278 0.0946\n",
" 0.1936 0.1369 -0.0694 ... -0.0667 0.0429 0.1322\n",
" 0.0870 -0.1884 0.1732 ... -0.1423 -0.1723 0.2147\n",
"[torch.FloatTensor of size 60x20]"
]
},
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gru_seq.weight_hh_l0"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 32, 20])"
]
},
"execution_count": 75,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"h.shape"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([3, 32, 20])"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out.shape"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
