You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

tutorial_9_callback.ipynb 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# 使用 Callback 自定义你的训练过程"
  8. ]
  9. },
  10. {
  11. "cell_type": "markdown",
  12. "metadata": {},
  13. "source": [
  14. "- 什么是 Callback\n",
  15. "- 使用 Callback \n",
  16. "- 一些常用的 Callback\n",
  17. "- 自定义实现 Callback"
  18. ]
  19. },
  20. {
  21. "cell_type": "markdown",
  22. "metadata": {},
  23. "source": [
  24. "什么是Callback\n",
  25. "------\n",
  26. "\n",
  27. "Callback 是与 Trainer 紧密结合的模块,利用 Callback 可以在 Trainer 训练时,加入自定义的操作,比如梯度裁剪,学习率调节,测试模型的性能等。定义的 Callback 会在训练的特定阶段被调用。\n",
  28. "\n",
  29. "fastNLP 中提供了很多常用的 Callback ,开箱即用。"
  30. ]
  31. },
  32. {
  33. "cell_type": "markdown",
  34. "metadata": {},
  35. "source": [
  36. "使用 Callback\n",
  37. " ------\n",
  38. "\n",
  39. "使用 Callback 很简单,将需要的 callback 按 list 存储,以对应参数 ``callbacks`` 传入对应的 Trainer。Trainer 在训练时就会自动执行这些 Callback 指定的操作了。"
  40. ]
  41. },
  42. {
  43. "cell_type": "code",
  44. "execution_count": 4,
  45. "metadata": {
  46. "ExecuteTime": {
  47. "end_time": "2019-09-17T07:34:46.465871Z",
  48. "start_time": "2019-09-17T07:34:30.648758Z"
  49. }
  50. },
  51. "outputs": [
  52. {
  53. "name": "stdout",
  54. "output_type": "stream",
  55. "text": [
  56. "In total 3 datasets:\n",
  57. "\ttest has 1200 instances.\n",
  58. "\ttrain has 9600 instances.\n",
  59. "\tdev has 1200 instances.\n",
  60. "In total 2 vocabs:\n",
  61. "\tchars has 4409 entries.\n",
  62. "\ttarget has 2 entries.\n",
  63. "\n",
  64. "training epochs started 2019-09-17-03-34-34\n"
  65. ]
  66. },
  67. {
  68. "data": {
  69. "application/vnd.jupyter.widget-view+json": {
  70. "model_id": "",
  71. "version_major": 2,
  72. "version_minor": 0
  73. },
  74. "text/plain": [
  75. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…"
  76. ]
  77. },
  78. "metadata": {},
  79. "output_type": "display_data"
  80. },
  81. {
  82. "data": {
  83. "application/vnd.jupyter.widget-view+json": {
  84. "model_id": "",
  85. "version_major": 2,
  86. "version_minor": 0
  87. },
  88. "text/plain": [
  89. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  90. ]
  91. },
  92. "metadata": {},
  93. "output_type": "display_data"
  94. },
  95. {
  96. "name": "stdout",
  97. "output_type": "stream",
  98. "text": [
  99. "Evaluate data in 0.1 seconds!\n",
  100. "Evaluation on dev at Epoch 1/3. Step:300/900: \n",
  101. "AccuracyMetric: acc=0.863333\n",
  102. "\n"
  103. ]
  104. },
  105. {
  106. "data": {
  107. "application/vnd.jupyter.widget-view+json": {
  108. "model_id": "",
  109. "version_major": 2,
  110. "version_minor": 0
  111. },
  112. "text/plain": [
  113. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  114. ]
  115. },
  116. "metadata": {},
  117. "output_type": "display_data"
  118. },
  119. {
  120. "name": "stdout",
  121. "output_type": "stream",
  122. "text": [
  123. "Evaluate data in 0.11 seconds!\n",
  124. "Evaluation on dev at Epoch 2/3. Step:600/900: \n",
  125. "AccuracyMetric: acc=0.886667\n",
  126. "\n"
  127. ]
  128. },
  129. {
  130. "data": {
  131. "application/vnd.jupyter.widget-view+json": {
  132. "model_id": "",
  133. "version_major": 2,
  134. "version_minor": 0
  135. },
  136. "text/plain": [
  137. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  138. ]
  139. },
  140. "metadata": {},
  141. "output_type": "display_data"
  142. },
  143. {
  144. "name": "stdout",
  145. "output_type": "stream",
  146. "text": [
  147. "Evaluate data in 0.1 seconds!\n",
  148. "Evaluation on dev at Epoch 3/3. Step:900/900: \n",
  149. "AccuracyMetric: acc=0.890833\n",
  150. "\n",
  151. "\r\n",
  152. "In Epoch:3/Step:900, got best dev performance:\n",
  153. "AccuracyMetric: acc=0.890833\n",
  154. "Reloaded the best model.\n"
  155. ]
  156. }
  157. ],
  158. "source": [
  159. "from fastNLP import (Callback, EarlyStopCallback,\n",
  160. " Trainer, CrossEntropyLoss, AccuracyMetric)\n",
  161. "from fastNLP.models import CNNText\n",
  162. "import torch.cuda\n",
  163. "\n",
  164. "# prepare data\n",
  165. "def get_data():\n",
  166. " from fastNLP.io import ChnSentiCorpPipe as pipe\n",
  167. " data = pipe().process_from_file()\n",
  168. " print(data)\n",
  169. " data.rename_field('chars', 'words')\n",
  170. " train_data = data.datasets['train']\n",
  171. " dev_data = data.datasets['dev']\n",
  172. " test_data = data.datasets['test']\n",
  173. " vocab = data.vocabs['words']\n",
  174. " tgt_vocab = data.vocabs['target']\n",
  175. " return train_data, dev_data, test_data, vocab, tgt_vocab\n",
  176. "\n",
  177. "# prepare model\n",
  178. "train_data, dev_data, _, vocab, tgt_vocab = get_data()\n",
  179. "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
  180. "model = CNNText((len(vocab),50), num_classes=len(tgt_vocab))\n",
  181. "\n",
  182. "# define callback\n",
  183. "callbacks=[EarlyStopCallback(5)]\n",
  184. "\n",
  185. "# pass callbacks to Trainer\n",
  186. "def train_with_callback(cb_list):\n",
  187. " trainer = Trainer(\n",
  188. " device=device,\n",
  189. " n_epochs=3,\n",
  190. " model=model, \n",
  191. " train_data=train_data, \n",
  192. " dev_data=dev_data, \n",
  193. " loss=CrossEntropyLoss(), \n",
  194. " metrics=AccuracyMetric(), \n",
  195. " callbacks=cb_list, \n",
  196. " check_code_level=-1\n",
  197. " )\n",
  198. " trainer.train()\n",
  199. "\n",
  200. "train_with_callback(callbacks)"
  201. ]
  202. },
  203. {
  204. "cell_type": "markdown",
  205. "metadata": {},
  206. "source": [
  207. "fastNLP 中的 Callback\n",
  208. "-------\n",
  209. "fastNLP 中提供了很多常用的 Callback,如梯度裁剪,训练时早停和测试验证集,fitlog 等等。具体 Callback 请参考 fastNLP.core.callbacks"
  210. ]
  211. },
  212. {
  213. "cell_type": "code",
  214. "execution_count": 5,
  215. "metadata": {
  216. "ExecuteTime": {
  217. "end_time": "2019-09-17T07:35:02.182727Z",
  218. "start_time": "2019-09-17T07:34:49.443863Z"
  219. }
  220. },
  221. "outputs": [
  222. {
  223. "name": "stdout",
  224. "output_type": "stream",
  225. "text": [
  226. "training epochs started 2019-09-17-03-34-49\n"
  227. ]
  228. },
  229. {
  230. "data": {
  231. "application/vnd.jupyter.widget-view+json": {
  232. "model_id": "",
  233. "version_major": 2,
  234. "version_minor": 0
  235. },
  236. "text/plain": [
  237. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…"
  238. ]
  239. },
  240. "metadata": {},
  241. "output_type": "display_data"
  242. },
  243. {
  244. "data": {
  245. "application/vnd.jupyter.widget-view+json": {
  246. "model_id": "",
  247. "version_major": 2,
  248. "version_minor": 0
  249. },
  250. "text/plain": [
  251. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  252. ]
  253. },
  254. "metadata": {},
  255. "output_type": "display_data"
  256. },
  257. {
  258. "name": "stdout",
  259. "output_type": "stream",
  260. "text": [
  261. "Evaluate data in 0.13 seconds!\n"
  262. ]
  263. },
  264. {
  265. "data": {
  266. "application/vnd.jupyter.widget-view+json": {
  267. "model_id": "",
  268. "version_major": 2,
  269. "version_minor": 0
  270. },
  271. "text/plain": [
  272. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  273. ]
  274. },
  275. "metadata": {},
  276. "output_type": "display_data"
  277. },
  278. {
  279. "name": "stdout",
  280. "output_type": "stream",
  281. "text": [
  282. "Evaluate data in 0.12 seconds!\n",
  283. "Evaluation on data-test:\n",
  284. "AccuracyMetric: acc=0.890833\n",
  285. "Evaluation on dev at Epoch 1/3. Step:300/900: \n",
  286. "AccuracyMetric: acc=0.890833\n",
  287. "\n"
  288. ]
  289. },
  290. {
  291. "data": {
  292. "application/vnd.jupyter.widget-view+json": {
  293. "model_id": "",
  294. "version_major": 2,
  295. "version_minor": 0
  296. },
  297. "text/plain": [
  298. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  299. ]
  300. },
  301. "metadata": {},
  302. "output_type": "display_data"
  303. },
  304. {
  305. "name": "stdout",
  306. "output_type": "stream",
  307. "text": [
  308. "Evaluate data in 0.09 seconds!\n"
  309. ]
  310. },
  311. {
  312. "data": {
  313. "application/vnd.jupyter.widget-view+json": {
  314. "model_id": "",
  315. "version_major": 2,
  316. "version_minor": 0
  317. },
  318. "text/plain": [
  319. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  320. ]
  321. },
  322. "metadata": {},
  323. "output_type": "display_data"
  324. },
  325. {
  326. "name": "stdout",
  327. "output_type": "stream",
  328. "text": [
  329. "Evaluate data in 0.09 seconds!\n",
  330. "Evaluation on data-test:\n",
  331. "AccuracyMetric: acc=0.8875\n",
  332. "Evaluation on dev at Epoch 2/3. Step:600/900: \n",
  333. "AccuracyMetric: acc=0.8875\n",
  334. "\n"
  335. ]
  336. },
  337. {
  338. "data": {
  339. "application/vnd.jupyter.widget-view+json": {
  340. "model_id": "",
  341. "version_major": 2,
  342. "version_minor": 0
  343. },
  344. "text/plain": [
  345. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  346. ]
  347. },
  348. "metadata": {},
  349. "output_type": "display_data"
  350. },
  351. {
  352. "name": "stdout",
  353. "output_type": "stream",
  354. "text": [
  355. "Evaluate data in 0.11 seconds!\n"
  356. ]
  357. },
  358. {
  359. "data": {
  360. "application/vnd.jupyter.widget-view+json": {
  361. "model_id": "",
  362. "version_major": 2,
  363. "version_minor": 0
  364. },
  365. "text/plain": [
  366. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  367. ]
  368. },
  369. "metadata": {},
  370. "output_type": "display_data"
  371. },
  372. {
  373. "name": "stdout",
  374. "output_type": "stream",
  375. "text": [
  376. "Evaluate data in 0.1 seconds!\n",
  377. "Evaluation on data-test:\n",
  378. "AccuracyMetric: acc=0.885\n",
  379. "Evaluation on dev at Epoch 3/3. Step:900/900: \n",
  380. "AccuracyMetric: acc=0.885\n",
  381. "\n",
  382. "\r\n",
  383. "In Epoch:1/Step:300, got best dev performance:\n",
  384. "AccuracyMetric: acc=0.890833\n",
  385. "Reloaded the best model.\n"
  386. ]
  387. }
  388. ],
  389. "source": [
  390. "from fastNLP import EarlyStopCallback, GradientClipCallback, EvaluateCallback\n",
  391. "callbacks = [\n",
  392. " EarlyStopCallback(5),\n",
  393. " GradientClipCallback(clip_value=5, clip_type='value'),\n",
  394. " EvaluateCallback(dev_data)\n",
  395. "]\n",
  396. "\n",
  397. "train_with_callback(callbacks)"
  398. ]
  399. },
  400. {
  401. "cell_type": "markdown",
  402. "metadata": {},
  403. "source": [
  404. "自定义 Callback\n",
  405. "------\n",
  406. "\n",
  407. "这里我们以一个简单的 Callback作为例子,它的作用是打印每一个 Epoch 平均训练 loss。\n",
  408. "\n",
  409. "#### 创建 Callback\n",
  410. " \n",
  411. "要自定义 Callback,我们要实现一个类,继承 fastNLP.Callback。\n",
  412. "\n",
  413. "这里我们定义 MyCallBack ,继承 fastNLP.Callback 。\n",
  414. "\n",
  415. "#### 指定 Callback 调用的阶段\n",
  416. " \n",
  417. "Callback 中所有以 on_ 开头的类方法会在 Trainer 的训练中在特定阶段调用。 如 on_train_begin() 会在训练开始时被调用,on_epoch_end() 会在每个 epoch 结束时调用。 具体有哪些类方法,参见 Callback 文档。\n",
  418. "\n",
  419. "这里, MyCallBack 在求得loss时调用 on_backward_begin() 记录当前 loss ,在每一个 epoch 结束时调用 on_epoch_end() ,求当前 epoch 平均loss并输出。\n",
  420. "\n",
  421. "#### 使用 Callback 的属性访问 Trainer 的内部信息\n",
  422. " \n",
  423. "为了方便使用,可以使用 Callback 的属性,访问 Trainer 中的对应信息,如 optimizer, epoch, n_epochs,分别对应训练时的优化器,当前 epoch 数,和总 epoch 数。 具体可访问的属性,参见文档 Callback 。\n",
  424. "\n",
  425. "这里, MyCallBack 为了求平均 loss ,需要知道当前 epoch 的总步数,可以通过 self.step 属性得到当前训练了多少步。\n",
  426. "\n"
  427. ]
  428. },
  429. {
  430. "cell_type": "code",
  431. "execution_count": 8,
  432. "metadata": {
  433. "ExecuteTime": {
  434. "end_time": "2019-09-17T07:43:10.907139Z",
  435. "start_time": "2019-09-17T07:42:58.488177Z"
  436. }
  437. },
  438. "outputs": [
  439. {
  440. "name": "stdout",
  441. "output_type": "stream",
  442. "text": [
  443. "training epochs started 2019-09-17-03-42-58\n"
  444. ]
  445. },
  446. {
  447. "data": {
  448. "application/vnd.jupyter.widget-view+json": {
  449. "model_id": "",
  450. "version_major": 2,
  451. "version_minor": 0
  452. },
  453. "text/plain": [
  454. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…"
  455. ]
  456. },
  457. "metadata": {},
  458. "output_type": "display_data"
  459. },
  460. {
  461. "data": {
  462. "application/vnd.jupyter.widget-view+json": {
  463. "model_id": "",
  464. "version_major": 2,
  465. "version_minor": 0
  466. },
  467. "text/plain": [
  468. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  469. ]
  470. },
  471. "metadata": {},
  472. "output_type": "display_data"
  473. },
  474. {
  475. "name": "stdout",
  476. "output_type": "stream",
  477. "text": [
  478. "Evaluate data in 0.11 seconds!\n",
  479. "Evaluation on dev at Epoch 1/3. Step:300/900: \n",
  480. "AccuracyMetric: acc=0.883333\n",
  481. "\n",
  482. "Avg loss at epoch 1, 0.100254\n"
  483. ]
  484. },
  485. {
  486. "data": {
  487. "application/vnd.jupyter.widget-view+json": {
  488. "model_id": "",
  489. "version_major": 2,
  490. "version_minor": 0
  491. },
  492. "text/plain": [
  493. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  494. ]
  495. },
  496. "metadata": {},
  497. "output_type": "display_data"
  498. },
  499. {
  500. "name": "stdout",
  501. "output_type": "stream",
  502. "text": [
  503. "Evaluate data in 0.1 seconds!\n",
  504. "Evaluation on dev at Epoch 2/3. Step:600/900: \n",
  505. "AccuracyMetric: acc=0.8775\n",
  506. "\n",
  507. "Avg loss at epoch 2, 0.183511\n"
  508. ]
  509. },
  510. {
  511. "data": {
  512. "application/vnd.jupyter.widget-view+json": {
  513. "model_id": "",
  514. "version_major": 2,
  515. "version_minor": 0
  516. },
  517. "text/plain": [
  518. "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
  519. ]
  520. },
  521. "metadata": {},
  522. "output_type": "display_data"
  523. },
  524. {
  525. "name": "stdout",
  526. "output_type": "stream",
  527. "text": [
  528. "Evaluate data in 0.13 seconds!\n",
  529. "Evaluation on dev at Epoch 3/3. Step:900/900: \n",
  530. "AccuracyMetric: acc=0.875833\n",
  531. "\n",
  532. "Avg loss at epoch 3, 0.257103\n",
  533. "\r\n",
  534. "In Epoch:1/Step:300, got best dev performance:\n",
  535. "AccuracyMetric: acc=0.883333\n",
  536. "Reloaded the best model.\n"
  537. ]
  538. }
  539. ],
  540. "source": [
  541. "from fastNLP import Callback\n",
  542. "from fastNLP import logger\n",
  543. "\n",
  544. "class MyCallBack(Callback):\n",
  545. " \"\"\"Print average loss in each epoch\"\"\"\n",
  546. " def __init__(self):\n",
  547. " super().__init__()\n",
  548. " self.total_loss = 0\n",
  549. " self.start_step = 0\n",
  550. " \n",
  551. " def on_backward_begin(self, loss):\n",
  552. " self.total_loss += loss.item()\n",
  553. " \n",
  554. " def on_epoch_end(self):\n",
  555. " n_steps = self.step - self.start_step\n",
  556. " avg_loss = self.total_loss / n_steps\n",
  557. " logger.info('Avg loss at epoch %d, %.6f', self.epoch, avg_loss)\n",
  558. " self.start_step = self.step\n",
  559. "\n",
  560. "callbacks = [MyCallBack()]\n",
  561. "train_with_callback(callbacks)"
  562. ]
  563. },
  564. {
  565. "cell_type": "code",
  566. "execution_count": null,
  567. "metadata": {},
  568. "outputs": [],
  569. "source": []
  570. }
  571. ],
  572. "metadata": {
  573. "kernelspec": {
  574. "display_name": "Python 3",
  575. "language": "python",
  576. "name": "python3"
  577. },
  578. "language_info": {
  579. "codemirror_mode": {
  580. "name": "ipython",
  581. "version": 3
  582. },
  583. "file_extension": ".py",
  584. "mimetype": "text/x-python",
  585. "name": "python",
  586. "nbconvert_exporter": "python",
  587. "pygments_lexer": "ipython3",
  588. "version": "3.7.3"
  589. },
  590. "varInspector": {
  591. "cols": {
  592. "lenName": 16,
  593. "lenType": 16,
  594. "lenVar": 40
  595. },
  596. "kernels_config": {
  597. "python": {
  598. "delete_cmd_postfix": "",
  599. "delete_cmd_prefix": "del ",
  600. "library": "var_list.py",
  601. "varRefreshCmd": "print(var_dic_list())"
  602. },
  603. "r": {
  604. "delete_cmd_postfix": ") ",
  605. "delete_cmd_prefix": "rm(",
  606. "library": "var_list.r",
  607. "varRefreshCmd": "cat(var_dic_list()) "
  608. }
  609. },
  610. "types_to_exclude": [
  611. "module",
  612. "function",
  613. "builtin_function_or_method",
  614. "instance",
  615. "_Feature"
  616. ],
  617. "window_display": false
  618. }
  619. },
  620. "nbformat": 4,
  621. "nbformat_minor": 4
  622. }