# POS Tagging with an LSTM

Earlier we covered word embeddings and word prediction with an n-gram model, but so far we have not used an RNN. In this final lesson we combine all of that background and show how to use an LSTM for part-of-speech (POS) tagging.

## Model Overview

A single word can take on different parts of speech. A first clue is the word's suffix: a word ending in -ly, for example, is very likely an adverb. Beyond that, one and the same word can carry two different parts of speech: "book" can be a noun or a verb, so the actual tag has to be decided from the surrounding context.

Given this, we can use an LSTM to make the prediction. First, a word can itself be viewed as a sequence: "apple" is made up of the 5 characters a p p l e, forming a sequence of length 5. We build embeddings for these characters and feed them into an LSTM; just as when using an LSTM for image classification, we keep only the last output as the result. The character string as a whole provides a kind of memory that helps us predict the part of speech, for instance by capturing suffixes.

![](https://ws3.sinaimg.cn/large/006tKfTcgy1fmxi67w0f7j30ap05qq2u.jpg)

Next we form a sequence from each word together with the words before it, build word embeddings for those words, and output the word's POS tag; in other words, we classify each word's part of speech from the information in the preceding words.

Below we walk through a simple example.
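To make the character-sequence idea concrete first, here is a minimal standalone sketch. It is not part of the original notebook: the dimensions 10 and 50 are arbitrary illustrative choices, and the `Variable` wrapper matches the PyTorch 0.3-era API used throughout this notebook.

```python
import torch
from torch import nn
from torch.autograd import Variable

chars = torch.LongTensor([0, 15, 15, 11, 4])  # 'a p p l e' as letter indices
embed = nn.Embedding(26, 10)                  # 26 letters -> 10-dim vectors
lstm = nn.LSTM(10, 50)                        # 10-dim input, 50-dim hidden state

x = embed(Variable(chars)).unsqueeze(1)       # (seq=5, batch=1, feature=10)
out, _ = lstm(x)                              # (5, 1, 50)
word_repr = out[-1]                           # (1, 50): the last step summarizes the word
```

With the idea in hand, the notebook itself begins with the imports: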
```python
import torch
from torch import nn
from torch.autograd import Variable
```
We use the simple training set below:
```python
training_data = [("The dog ate the apple".split(),
                  ["DET", "NN", "V", "DET", "NN"]),
                 ("Everybody read that book".split(),
                  ["NN", "V", "DET", "NN"])]
```
Next we need to encode the words and the tags:
```python
word_to_idx = {}
tag_to_idx = {}
for context, tag in training_data:
    for word in context:
        if word.lower() not in word_to_idx:
            word_to_idx[word.lower()] = len(word_to_idx)
    for label in tag:
        if label.lower() not in tag_to_idx:
            tag_to_idx[label.lower()] = len(tag_to_idx)
```
```python
word_to_idx
```

```
{'apple': 3,
 'ate': 2,
 'book': 7,
 'dog': 1,
 'everybody': 4,
 'read': 5,
 'that': 6,
 'the': 0}
```
```python
tag_to_idx
```

```
{'det': 0, 'nn': 1, 'v': 2}
```
Then we encode the individual letters:
```python
alphabet = 'abcdefghijklmnopqrstuvwxyz'
char_to_idx = {}
for i in range(len(alphabet)):
    char_to_idx[alphabet[i]] = i
```
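As an aside, the same mapping can be built in one line with a dict comprehension; this is just an equivalent variant, not what the notebook runs:

```python
char_to_idx = {c: i for i, c in enumerate(alphabet)}  # identical result
```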
```python
char_to_idx
```

```
{'a': 0,
 'b': 1,
 'c': 2,
 'd': 3,
 'e': 4,
 'f': 5,
 'g': 6,
 'h': 7,
 'i': 8,
 'j': 9,
 'k': 10,
 'l': 11,
 'm': 12,
 'n': 13,
 'o': 14,
 'p': 15,
 'q': 16,
 'r': 17,
 's': 18,
 't': 19,
 'u': 20,
 'v': 21,
 'w': 22,
 'x': 23,
 'y': 24,
 'z': 25}
```
Now we can build the training data:
```python
def make_sequence(x, dic):  # encode a sequence of tokens as indices
    idx = [dic[i.lower()] for i in x]
    idx = torch.LongTensor(idx)
    return idx
```
```python
make_sequence('apple', char_to_idx)
```

```
 0
 15
 15
 11
 4
[torch.LongTensor of size 5]
```
```python
training_data[1][0]
```

```
['Everybody', 'read', 'that', 'book']
```
```python
make_sequence(training_data[1][0], word_to_idx)
```

```
 4
 5
 6
 7
[torch.LongTensor of size 4]
```
Build the character-level LSTM for a single word:
```python
class char_lstm(nn.Module):
    def __init__(self, n_char, char_dim, char_hidden):
        super(char_lstm, self).__init__()

        self.char_embed = nn.Embedding(n_char, char_dim)
        self.lstm = nn.LSTM(char_dim, char_hidden)

    def forward(self, x):
        x = self.char_embed(x)
        out, _ = self.lstm(x)
        return out[-1]  # (batch, hidden)
```
Build the LSTM model for POS classification:
```python
class lstm_tagger(nn.Module):
    def __init__(self, n_word, n_char, char_dim, word_dim,
                 char_hidden, word_hidden, n_tag):
        super(lstm_tagger, self).__init__()
        self.word_embed = nn.Embedding(n_word, word_dim)
        self.char_lstm = char_lstm(n_char, char_dim, char_hidden)
        self.word_lstm = nn.LSTM(word_dim + char_hidden, word_hidden)
        self.classify = nn.Linear(word_hidden, n_tag)

    def forward(self, x, word):
        char = []
        for w in word:  # run the character LSTM over each word
            char_list = make_sequence(w, char_to_idx)
            char_list = char_list.unsqueeze(1)  # (seq, batch, feature), as the LSTM expects
            char_infor = self.char_lstm(Variable(char_list))  # (batch, char_hidden)
            char.append(char_infor)
        char = torch.stack(char, dim=0)  # (seq, batch, feature)

        x = self.word_embed(x)  # (batch, seq, word_dim)
        x = x.permute(1, 0, 2)  # reorder to (seq, batch, word_dim)
        x = torch.cat((x, char), dim=2)  # concatenate each word's embedding with its char-LSTM output along the feature dimension
        x, _ = self.word_lstm(x)

        s, b, h = x.shape
        x = x.view(-1, h)  # reshape so each word goes through the classification linear layer
        out = self.classify(x)
        return out
```
```python
net = lstm_tagger(len(word_to_idx), len(char_to_idx), 10, 100, 50, 128, len(tag_to_idx))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)
```
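As a quick sanity check (an addition, not in the original notebook): the tagger should produce one row of logits per word, which is exactly the (N, C) shape that `nn.CrossEntropyLoss` expects against an (N,) target of tag indices.

```python
words = training_data[0][0]  # ['The', 'dog', 'ate', 'the', 'apple']
word_list = Variable(make_sequence(words, word_to_idx).unsqueeze(0))  # (batch=1, seq=5)
out = net(word_list, words)
print(out.shape)  # expect (5, 3): 5 words, 3 possible tags
```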
```python
# start training
for e in range(300):
    train_loss = 0
    for word, tag in training_data:
        word_list = make_sequence(word, word_to_idx).unsqueeze(0)  # add a leading batch dimension
        tag = make_sequence(tag, tag_to_idx)
        word_list = Variable(word_list)
        tag = Variable(tag)
        # forward pass
        out = net(word_list, word)
        loss = criterion(out, tag)
        train_loss += loss.data[0]
        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if (e + 1) % 50 == 0:
        print('Epoch: {}, Loss: {:.5f}'.format(e + 1, train_loss / len(training_data)))
```

```
Epoch: 50, Loss: 0.86690
Epoch: 100, Loss: 0.65471
Epoch: 150, Loss: 0.45582
Epoch: 200, Loss: 0.30351
Epoch: 250, Loss: 0.20446
Epoch: 300, Loss: 0.14376
```
Finally, let's look at the predictions:
```python
net = net.eval()
```
```python
test_sent = 'Everybody ate the apple'
test = make_sequence(test_sent.split(), word_to_idx).unsqueeze(0)
out = net(Variable(test), test_sent.split())
```
```python
print(out)
```

```
Variable containing:
-1.2148  1.9048 -0.6570
-0.9272 -0.4441  1.4009
 1.6425 -0.7751 -1.1553
-0.6121  1.6036 -1.1280
[torch.FloatTensor of size 4x3]
```
```python
print(tag_to_idx)
```

```
{'det': 0, 'nn': 1, 'v': 2}
```
This gives the result above. Because there is no softmax after the final linear layer, the values do not look like probabilities, but in each row the largest value marks the predicted class. Reading them off: the first word 'Everybody' is tagged nn, the second word 'ate' is v, the third word 'the' is det, and the fourth word 'apple' is nn, so the prediction is correct.
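Rather than reading the logits by eye, we can decode them programmatically. A small sketch (again an addition to the original notebook) that inverts `tag_to_idx` and takes the row-wise argmax:

```python
idx_to_tag = {i: t for t, i in tag_to_idx.items()}  # invert the tag dictionary
_, pred = torch.max(out.data, 1)                    # index of the largest logit in each row
print([idx_to_tag[i] for i in pred.tolist()])       # ['nn', 'v', 'det', 'nn'] for the logits above
```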
