You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

pytorch-rnn.ipynb 18 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# PyTorch 中的循环神经网络模块\n",
  8. "前面我们讲了循环神经网络的基础知识和网络结构,下面我们教大家如何在 pytorch 下构建循环神经网络,因为 pytorch 的动态图机制,使得循环神经网络非常方便。"
  9. ]
  10. },
  11. {
  12. "cell_type": "markdown",
  13. "metadata": {},
  14. "source": [
  15. "## 一般的 RNN\n",
  16. "\n",
  17. "![](https://ws1.sinaimg.cn/large/006tKfTcly1fmt9xz889xj30kb07nglo.jpg)\n",
  18. "\n",
  19. "对于最简单的 RNN,我们可以使用下面两种方式去调用,分别是 `torch.nn.RNNCell()` 和 `torch.nn.RNN()`,这两种方式的区别在于 `RNNCell()` 只能接受序列中单步的输入,且必须传入隐藏状态,而 `RNN()` 可以接受一个序列的输入,默认会传入全 0 的隐藏状态,也可以自己申明隐藏状态传入。\n",
  20. "\n",
  21. "`RNN()` 里面的参数有\n",
  22. "\n",
  23. "input_size 表示输入 $x_t$ 的特征维度\n",
  24. "\n",
  25. "hidden_size 表示输出的特征维度\n",
  26. "\n",
  27. "num_layers 表示网络的层数\n",
  28. "\n",
  29. "nonlinearity 表示选用的非线性激活函数,默认是 'tanh'\n",
  30. "\n",
  31. "bias 表示是否使用偏置,默认使用\n",
  32. "\n",
  33. "batch_first 表示输入数据的形式,默认是 False,即输入形式为 (seq, batch, feature),也就是将序列长度放在第一位,batch 放在第二位\n",
  34. "\n",
  35. "dropout 表示除最后一层外,每层输出上应用的 dropout 概率,默认是 0\n",
  36. "\n",
  37. "bidirectional 表示是否使用双向的 rnn,默认是 False\n",
  38. "\n",
  39. "对于 `RNNCell()`,里面的参数就少很多,只有 input_size,hidden_size,bias 以及 nonlinearity"
  40. ]
  41. },
  42. {
  43. "cell_type": "code",
  44. "execution_count": 46,
  45. "metadata": {
  46. "collapsed": true
  47. },
  48. "outputs": [],
  49. "source": [
  50. "import torch\n",
  51. "from torch.autograd import Variable\n",
  52. "from torch import nn"
  53. ]
  54. },
  55. {
  56. "cell_type": "code",
  57. "execution_count": 47,
  58. "metadata": {
  59. "collapsed": true
  60. },
  61. "outputs": [],
  62. "source": [
  63. "# 定义一个单步的 rnn\n",
  64. "rnn_single = nn.RNNCell(input_size=100, hidden_size=200)"
  65. ]
  66. },
  67. {
  68. "cell_type": "code",
  69. "execution_count": 48,
  70. "metadata": {
  71. "collapsed": false
  72. },
  73. "outputs": [
  74. {
  75. "data": {
  76. "text/plain": [
  77. "Parameter containing:\n",
  78. "1.00000e-02 *\n",
  79. " 6.2260 -5.3805 3.5870 ... -2.2162 6.2760 1.6760\n",
  80. "-5.1878 -4.6751 -5.5926 ... -1.8942 0.1589 1.0725\n",
  81. " 3.3236 -3.2726 5.5399 ... 3.3193 0.2117 1.1730\n",
  82. " ... ⋱ ... \n",
  83. " 2.4032 -3.4415 5.1036 ... -2.2035 -0.1900 -6.4016\n",
  84. " 5.2031 -1.5793 -0.0623 ... 0.3424 6.9412 6.3707\n",
  85. "-5.4495 4.5280 2.1774 ... 1.8767 2.4968 5.3403\n",
  86. "[torch.FloatTensor of size 200x200]"
  87. ]
  88. },
  89. "execution_count": 48,
  90. "metadata": {},
  91. "output_type": "execute_result"
  92. }
  93. ],
  94. "source": [
  95. "# 访问其中的参数\n",
  96. "rnn_single.weight_hh"
  97. ]
  98. },
  99. {
  100. "cell_type": "code",
  101. "execution_count": 49,
  102. "metadata": {
  103. "collapsed": true
  104. },
  105. "outputs": [],
  106. "source": [
  107. "# 构造一个序列,长为 6,batch 是 5, 特征是 100\n",
  108. "x = Variable(torch.randn(6, 5, 100)) # 这是 rnn 的输入格式"
  109. ]
  110. },
  111. {
  112. "cell_type": "code",
  113. "execution_count": 50,
  114. "metadata": {
  115. "collapsed": true
  116. },
  117. "outputs": [],
  118. "source": [
  119. "# 定义初始的记忆状态\n",
  120. "h_t = Variable(torch.zeros(5, 200))"
  121. ]
  122. },
  123. {
  124. "cell_type": "code",
  125. "execution_count": 51,
  126. "metadata": {
  127. "collapsed": true
  128. },
  129. "outputs": [],
  130. "source": [
  131. "# 传入 rnn\n",
  132. "out = []\n",
  133. "for i in range(6): # 通过循环 6 次作用在整个序列上\n",
  134. " h_t = rnn_single(x[i], h_t)\n",
  135. " out.append(h_t)"
  136. ]
  137. },
  138. {
  139. "cell_type": "code",
  140. "execution_count": 52,
  141. "metadata": {
  142. "collapsed": false
  143. },
  144. "outputs": [
  145. {
  146. "data": {
  147. "text/plain": [
  148. "Variable containing:\n",
  149. " 0.0136 0.3723 0.1704 ... 0.4306 -0.7909 -0.5306\n",
  150. "-0.2681 -0.6261 -0.3926 ... 0.1752 0.5739 -0.2061\n",
  151. "-0.4918 -0.7611 0.2787 ... 0.0854 -0.3899 0.0092\n",
  152. " 0.6050 0.1852 -0.4261 ... -0.7220 0.6809 0.1825\n",
  153. "-0.6851 0.7273 0.5396 ... -0.7969 0.6133 -0.0852\n",
  154. "[torch.FloatTensor of size 5x200]"
  155. ]
  156. },
  157. "execution_count": 52,
  158. "metadata": {},
  159. "output_type": "execute_result"
  160. }
  161. ],
  162. "source": [
  163. "h_t"
  164. ]
  165. },
  166. {
  167. "cell_type": "code",
  168. "execution_count": 54,
  169. "metadata": {
  170. "collapsed": false
  171. },
  172. "outputs": [
  173. {
  174. "data": {
  175. "text/plain": [
  176. "6"
  177. ]
  178. },
  179. "execution_count": 54,
  180. "metadata": {},
  181. "output_type": "execute_result"
  182. }
  183. ],
  184. "source": [
  185. "len(out)"
  186. ]
  187. },
  188. {
  189. "cell_type": "code",
  190. "execution_count": 55,
  191. "metadata": {
  192. "collapsed": false
  193. },
  194. "outputs": [
  195. {
  196. "data": {
  197. "text/plain": [
  198. "torch.Size([5, 200])"
  199. ]
  200. },
  201. "execution_count": 55,
  202. "metadata": {},
  203. "output_type": "execute_result"
  204. }
  205. ],
  206. "source": [
  207. "out[0].shape # 每个输出的维度"
  208. ]
  209. },
  210. {
  211. "cell_type": "markdown",
  212. "metadata": {},
  213. "source": [
  214. "可以看到经过了 rnn 之后,隐藏状态的值已经被改变了,因为网络记忆了序列中的信息,同时输出 6 个结果"
  215. ]
  216. },
  217. {
  218. "cell_type": "markdown",
  219. "metadata": {},
  220. "source": [
  221. "下面我们看看直接使用 `RNN` 的情况"
  222. ]
  223. },
  224. {
  225. "cell_type": "code",
  226. "execution_count": 32,
  227. "metadata": {
  228. "collapsed": true
  229. },
  230. "outputs": [],
  231. "source": [
  232. "rnn_seq = nn.RNN(100, 200)"
  233. ]
  234. },
  235. {
  236. "cell_type": "code",
  237. "execution_count": 33,
  238. "metadata": {
  239. "collapsed": false
  240. },
  241. "outputs": [
  242. {
  243. "data": {
  244. "text/plain": [
  245. "Parameter containing:\n",
  246. "1.00000e-02 *\n",
  247. " 1.0998 -1.5018 -1.4337 ... 3.8385 -0.8958 -1.6781\n",
  248. " 5.3302 -5.4654 5.5568 ... 4.7399 5.4110 3.6170\n",
  249. " 1.0788 -0.6620 5.7689 ... -5.0747 -2.9066 0.6152\n",
  250. " ... ⋱ ... \n",
  251. "-5.6921 0.1843 -0.0803 ... -4.5852 5.6194 -1.4734\n",
  252. " 4.4306 6.9795 -1.5736 ... 3.4236 -0.3441 3.1397\n",
  253. " 7.0349 -1.6120 -4.2840 ... -5.5676 6.8897 6.1968\n",
  254. "[torch.FloatTensor of size 200x200]"
  255. ]
  256. },
  257. "execution_count": 33,
  258. "metadata": {},
  259. "output_type": "execute_result"
  260. }
  261. ],
  262. "source": [
  263. "# 访问其中的参数\n",
  264. "rnn_seq.weight_hh_l0"
  265. ]
  266. },
  267. {
  268. "cell_type": "code",
  269. "execution_count": 34,
  270. "metadata": {
  271. "collapsed": true
  272. },
  273. "outputs": [],
  274. "source": [
  275. "out, h_t = rnn_seq(x) # 使用默认的全 0 隐藏状态"
  276. ]
  277. },
  278. {
  279. "cell_type": "code",
  280. "execution_count": 36,
  281. "metadata": {
  282. "collapsed": false
  283. },
  284. "outputs": [
  285. {
  286. "data": {
  287. "text/plain": [
  288. "Variable containing:\n",
  289. "( 0 ,.,.) = \n",
  290. " 0.2012 0.0517 0.0570 ... 0.2316 0.3615 -0.1247\n",
  291. " 0.5307 0.4147 0.7881 ... -0.4138 -0.1444 0.3602\n",
  292. " 0.0882 0.4307 0.3939 ... 0.3244 -0.4629 -0.2315\n",
  293. " 0.2868 0.7400 0.6534 ... 0.6631 0.2624 -0.0162\n",
  294. " 0.0841 0.6274 0.1840 ... 0.5800 0.8780 0.4301\n",
  295. "[torch.FloatTensor of size 1x5x200]"
  296. ]
  297. },
  298. "execution_count": 36,
  299. "metadata": {},
  300. "output_type": "execute_result"
  301. }
  302. ],
  303. "source": [
  304. "h_t"
  305. ]
  306. },
  307. {
  308. "cell_type": "code",
  309. "execution_count": 35,
  310. "metadata": {
  311. "collapsed": false
  312. },
  313. "outputs": [
  314. {
  315. "data": {
  316. "text/plain": [
  317. "6"
  318. ]
  319. },
  320. "execution_count": 35,
  321. "metadata": {},
  322. "output_type": "execute_result"
  323. }
  324. ],
  325. "source": [
  326. "len(out)"
  327. ]
  328. },
  329. {
  330. "cell_type": "markdown",
  331. "metadata": {},
  332. "source": [
  333. "这里的 h_t 是网络最后的隐藏状态,网络也输出了 6 个结果"
  334. ]
  335. },
  336. {
  337. "cell_type": "code",
  338. "execution_count": 40,
  339. "metadata": {
  340. "collapsed": true
  341. },
  342. "outputs": [],
  343. "source": [
  344. "# 自己定义初始的隐藏状态\n",
  345. "h_0 = Variable(torch.randn(1, 5, 200))"
  346. ]
  347. },
  348. {
  349. "cell_type": "markdown",
  350. "metadata": {},
  351. "source": [
  352. "这里的隐藏状态的大小有三个维度,分别是 (num_layers * num_direction, batch, hidden_size)"
  353. ]
  354. },
  355. {
  356. "cell_type": "code",
  357. "execution_count": 41,
  358. "metadata": {
  359. "collapsed": false
  360. },
  361. "outputs": [],
  362. "source": [
  363. "out, h_t = rnn_seq(x, h_0)"
  364. ]
  365. },
  366. {
  367. "cell_type": "code",
  368. "execution_count": 42,
  369. "metadata": {
  370. "collapsed": false
  371. },
  372. "outputs": [
  373. {
  374. "data": {
  375. "text/plain": [
  376. "Variable containing:\n",
  377. "( 0 ,.,.) = \n",
  378. " 0.2091 0.0353 0.0625 ... 0.2340 0.3734 -0.1307\n",
  379. " 0.5498 0.4221 0.7877 ... -0.4143 -0.1209 0.3335\n",
  380. " 0.0757 0.4204 0.3826 ... 0.3187 -0.4626 -0.2336\n",
  381. " 0.3106 0.7355 0.6436 ... 0.6611 0.2587 -0.0338\n",
  382. " 0.1025 0.6350 0.1943 ... 0.5720 0.8749 0.4525\n",
  383. "[torch.FloatTensor of size 1x5x200]"
  384. ]
  385. },
  386. "execution_count": 42,
  387. "metadata": {},
  388. "output_type": "execute_result"
  389. }
  390. ],
  391. "source": [
  392. "h_t"
  393. ]
  394. },
  395. {
  396. "cell_type": "code",
  397. "execution_count": 45,
  398. "metadata": {
  399. "collapsed": false
  400. },
  401. "outputs": [
  402. {
  403. "data": {
  404. "text/plain": [
  405. "torch.Size([6, 5, 200])"
  406. ]
  407. },
  408. "execution_count": 45,
  409. "metadata": {},
  410. "output_type": "execute_result"
  411. }
  412. ],
  413. "source": [
  414. "out.shape"
  415. ]
  416. },
  417. {
  418. "cell_type": "markdown",
  419. "metadata": {},
  420. "source": [
  421. "同时输出的结果也是 (seq, batch, feature)"
  422. ]
  423. },
  424. {
  425. "cell_type": "markdown",
  426. "metadata": {},
  427. "source": [
  428. "一般情况下我们都是用 `nn.RNN()` 而不是 `nn.RNNCell()`,因为 `nn.RNN()` 能够避免我们手动写循环,非常方便,同时如果不特别说明,我们也会选择使用默认的全 0 初始化隐藏状态"
  429. ]
  430. },
  431. {
  432. "cell_type": "markdown",
  433. "metadata": {},
  434. "source": [
  435. "## LSTM"
  436. ]
  437. },
  438. {
  439. "cell_type": "markdown",
  440. "metadata": {},
  441. "source": [
  442. "![](https://ws1.sinaimg.cn/large/006tKfTcly1fmt9qj3uhmj30iz07ct90.jpg)"
  443. ]
  444. },
  445. {
  446. "cell_type": "markdown",
  447. "metadata": {
  448. "collapsed": true
  449. },
  450. "source": [
  451. "LSTM 和基本的 RNN 是一样的,他的参数也是相同的,同时他也有 `nn.LSTMCell()` 和 `nn.LSTM()` 两种形式,跟前面讲的都是相同的,我们就不再赘述了,下面直接举个小例子"
  452. ]
  453. },
  454. {
  455. "cell_type": "code",
  456. "execution_count": 58,
  457. "metadata": {
  458. "collapsed": true
  459. },
  460. "outputs": [],
  461. "source": [
  462. "lstm_seq = nn.LSTM(50, 100, num_layers=2) # 输入维度 50,输出 100,两层"
  463. ]
  464. },
  465. {
  466. "cell_type": "code",
  467. "execution_count": 80,
  468. "metadata": {
  469. "collapsed": false
  470. },
  471. "outputs": [
  472. {
  473. "data": {
  474. "text/plain": [
  475. "Parameter containing:\n",
  476. "1.00000e-02 *\n",
  477. " 3.8420 5.7387 6.1351 ... 1.2680 0.9890 1.3037\n",
  478. "-4.2301 6.8294 -4.8627 ... -6.4147 4.3015 8.4103\n",
  479. " 9.4411 5.0195 9.8620 ... -1.6096 9.2516 -0.6941\n",
  480. " ... ⋱ ... \n",
  481. " 1.2930 -1.3300 -0.9311 ... -6.0891 -0.7164 3.9578\n",
  482. " 9.0435 2.4674 9.4107 ... -3.3822 -3.9773 -3.0685\n",
  483. "-4.2039 -8.2992 -3.3605 ... 2.2875 8.2163 -9.3277\n",
  484. "[torch.FloatTensor of size 400x100]"
  485. ]
  486. },
  487. "execution_count": 80,
  488. "metadata": {},
  489. "output_type": "execute_result"
  490. }
  491. ],
  492. "source": [
  493. "lstm_seq.weight_hh_l0 # 第一层的 h_t 权重"
  494. ]
  495. },
  496. {
  497. "cell_type": "markdown",
  498. "metadata": {},
  499. "source": [
  500. "**小练习:想想为什么这个系数的大小是 (400, 100)**"
  501. ]
  502. },
  503. {
  504. "cell_type": "code",
  505. "execution_count": 59,
  506. "metadata": {
  507. "collapsed": false
  508. },
  509. "outputs": [],
  510. "source": [
  511. "lstm_input = Variable(torch.randn(10, 3, 50)) # 序列 10,batch 是 3,输入维度 50"
  512. ]
  513. },
  514. {
  515. "cell_type": "code",
  516. "execution_count": 64,
  517. "metadata": {
  518. "collapsed": true
  519. },
  520. "outputs": [],
  521. "source": [
  522. "out, (h, c) = lstm_seq(lstm_input) # 使用默认的全 0 隐藏状态"
  523. ]
  524. },
  525. {
  526. "cell_type": "markdown",
  527. "metadata": {},
  528. "source": [
  529. "注意这里 LSTM 输出的隐藏状态有两个,h 和 c,就是上图中的每个 cell 之间的两个箭头,这两个隐藏状态的大小都是相同的,(num_layers * direction, batch, feature)"
  530. ]
  531. },
  532. {
  533. "cell_type": "code",
  534. "execution_count": 66,
  535. "metadata": {
  536. "collapsed": false
  537. },
  538. "outputs": [
  539. {
  540. "data": {
  541. "text/plain": [
  542. "torch.Size([2, 3, 100])"
  543. ]
  544. },
  545. "execution_count": 66,
  546. "metadata": {},
  547. "output_type": "execute_result"
  548. }
  549. ],
  550. "source": [
  551. "h.shape # 两层,Batch 是 3,特征是 100"
  552. ]
  553. },
  554. {
  555. "cell_type": "code",
  556. "execution_count": 67,
  557. "metadata": {
  558. "collapsed": false
  559. },
  560. "outputs": [
  561. {
  562. "data": {
  563. "text/plain": [
  564. "torch.Size([2, 3, 100])"
  565. ]
  566. },
  567. "execution_count": 67,
  568. "metadata": {},
  569. "output_type": "execute_result"
  570. }
  571. ],
  572. "source": [
  573. "c.shape"
  574. ]
  575. },
  576. {
  577. "cell_type": "code",
  578. "execution_count": 61,
  579. "metadata": {
  580. "collapsed": false
  581. },
  582. "outputs": [
  583. {
  584. "data": {
  585. "text/plain": [
  586. "torch.Size([10, 3, 100])"
  587. ]
  588. },
  589. "execution_count": 61,
  590. "metadata": {},
  591. "output_type": "execute_result"
  592. }
  593. ],
  594. "source": [
  595. "out.shape"
  596. ]
  597. },
  598. {
  599. "cell_type": "markdown",
  600. "metadata": {},
  601. "source": [
  602. "我们可以不使用默认的隐藏状态,这时需要传入两个张量"
  603. ]
  604. },
  605. {
  606. "cell_type": "code",
  607. "execution_count": 68,
  608. "metadata": {
  609. "collapsed": true
  610. },
  611. "outputs": [],
  612. "source": [
  613. "h_init = Variable(torch.randn(2, 3, 100))\n",
  614. "c_init = Variable(torch.randn(2, 3, 100))"
  615. ]
  616. },
  617. {
  618. "cell_type": "code",
  619. "execution_count": 69,
  620. "metadata": {
  621. "collapsed": true
  622. },
  623. "outputs": [],
  624. "source": [
  625. "out, (h, c) = lstm_seq(lstm_input, (h_init, c_init))"
  626. ]
  627. },
  628. {
  629. "cell_type": "code",
  630. "execution_count": 70,
  631. "metadata": {
  632. "collapsed": false
  633. },
  634. "outputs": [
  635. {
  636. "data": {
  637. "text/plain": [
  638. "torch.Size([2, 3, 100])"
  639. ]
  640. },
  641. "execution_count": 70,
  642. "metadata": {},
  643. "output_type": "execute_result"
  644. }
  645. ],
  646. "source": [
  647. "h.shape"
  648. ]
  649. },
  650. {
  651. "cell_type": "code",
  652. "execution_count": 71,
  653. "metadata": {
  654. "collapsed": false
  655. },
  656. "outputs": [
  657. {
  658. "data": {
  659. "text/plain": [
  660. "torch.Size([2, 3, 100])"
  661. ]
  662. },
  663. "execution_count": 71,
  664. "metadata": {},
  665. "output_type": "execute_result"
  666. }
  667. ],
  668. "source": [
  669. "c.shape"
  670. ]
  671. },
  672. {
  673. "cell_type": "code",
  674. "execution_count": 72,
  675. "metadata": {
  676. "collapsed": false
  677. },
  678. "outputs": [
  679. {
  680. "data": {
  681. "text/plain": [
  682. "torch.Size([10, 3, 100])"
  683. ]
  684. },
  685. "execution_count": 72,
  686. "metadata": {},
  687. "output_type": "execute_result"
  688. }
  689. ],
  690. "source": [
  691. "out.shape"
  692. ]
  693. },
  694. {
  695. "cell_type": "markdown",
  696. "metadata": {},
  697. "source": [
  698. "## GRU\n",
  699. "![](https://ws3.sinaimg.cn/large/006tKfTcly1fmtaj38y9sj30io06bmxc.jpg)"
  700. ]
  701. },
  702. {
  703. "cell_type": "markdown",
  704. "metadata": {},
  705. "source": [
  706. "GRU 和前面讲的这两个是同样的道理,就不再细说,还是演示一下例子"
  707. ]
  708. },
  709. {
  710. "cell_type": "code",
  711. "execution_count": 73,
  712. "metadata": {
  713. "collapsed": true
  714. },
  715. "outputs": [],
  716. "source": [
  717. "gru_seq = nn.GRU(10, 20)\n",
  718. "gru_input = Variable(torch.randn(3, 32, 10))\n",
  719. "\n",
  720. "out, h = gru_seq(gru_input)"
  721. ]
  722. },
  723. {
  724. "cell_type": "code",
  725. "execution_count": 76,
  726. "metadata": {
  727. "collapsed": false
  728. },
  729. "outputs": [
  730. {
  731. "data": {
  732. "text/plain": [
  733. "Parameter containing:\n",
  734. " 0.0766 -0.0548 -0.2008 ... -0.0250 -0.1819 0.1453\n",
  735. "-0.1676 0.1622 0.0417 ... 0.1905 -0.0071 -0.1038\n",
  736. " 0.0444 -0.1516 0.2194 ... -0.0009 0.0771 0.0476\n",
  737. " ... ⋱ ... \n",
  738. " 0.1698 -0.1707 0.0340 ... -0.1315 0.1278 0.0946\n",
  739. " 0.1936 0.1369 -0.0694 ... -0.0667 0.0429 0.1322\n",
  740. " 0.0870 -0.1884 0.1732 ... -0.1423 -0.1723 0.2147\n",
  741. "[torch.FloatTensor of size 60x20]"
  742. ]
  743. },
  744. "execution_count": 76,
  745. "metadata": {},
  746. "output_type": "execute_result"
  747. }
  748. ],
  749. "source": [
  750. "gru_seq.weight_hh_l0"
  751. ]
  752. },
  753. {
  754. "cell_type": "code",
  755. "execution_count": 75,
  756. "metadata": {
  757. "collapsed": false
  758. },
  759. "outputs": [
  760. {
  761. "data": {
  762. "text/plain": [
  763. "torch.Size([1, 32, 20])"
  764. ]
  765. },
  766. "execution_count": 75,
  767. "metadata": {},
  768. "output_type": "execute_result"
  769. }
  770. ],
  771. "source": [
  772. "h.shape"
  773. ]
  774. },
  775. {
  776. "cell_type": "code",
  777. "execution_count": 74,
  778. "metadata": {
  779. "collapsed": false
  780. },
  781. "outputs": [
  782. {
  783. "data": {
  784. "text/plain": [
  785. "torch.Size([3, 32, 20])"
  786. ]
  787. },
  788. "execution_count": 74,
  789. "metadata": {},
  790. "output_type": "execute_result"
  791. }
  792. ],
  793. "source": [
  794. "out.shape"
  795. ]
  796. }
  797. ],
  798. "metadata": {
  799. "kernelspec": {
  800. "display_name": "mx",
  801. "language": "python",
  802. "name": "mx"
  803. },
  804. "language_info": {
  805. "codemirror_mode": {
  806. "name": "ipython",
  807. "version": 3
  808. },
  809. "file_extension": ".py",
  810. "mimetype": "text/x-python",
  811. "name": "python",
  812. "nbconvert_exporter": "python",
  813. "pygments_lexer": "ipython3",
  814. "version": "3.6.0"
  815. }
  816. },
  817. "nbformat": 4,
  818. "nbformat_minor": 2
  819. }

机器学习越来越多应用到飞行器、机器人等领域,其目的是利用计算机实现类似人类的智能,从而实现装备的智能化与无人化。本课程旨在引导学生掌握机器学习的基本知识、典型方法与技术,通过具体的应用案例激发学生对该学科的兴趣,鼓励学生能够从人工智能的角度来分析、解决飞行器、机器人所面临的问题和挑战。本课程主要内容包括Python编程基础,机器学习模型,无监督学习、监督学习、深度学习基础知识与实现,并学习如何利用机器学习解决实际问题,从而全面提升自我的《综合能力》。