
fastnlp_tutorial_paddle_e2.ipynb 98 kB

  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# Training a Chinese Reading Comprehension Model with paddlenlp and FastNLP\n",
  8. "\n",
  9. "This tutorial is part of the **`paddle examples` series of the `FastNLP v0.8 tutorial`**. In it, we show how to complete a more advanced question-answering task in `FastNLP` by defining a custom `Metric` and loss function.\n",
  10. "\n",
  11. "1. Background: the reading comprehension task in NLP\n",
  12. "\n",
  13. "2. Preparation: load the `DuReader-robust` dataset and preprocess it with a `tokenizer`\n",
  14. "\n",
  15. "3. Training: define our own evaluation `Metric` for more flexible task evaluation"
  16. ]
  17. },
  18. {
  19. "cell_type": "markdown",
  20. "metadata": {},
  21. "source": [
  22. "### 1. Background: the reading comprehension task in NLP\n",
  23. "\n",
  24. "Reading comprehension, as the name suggests, gives the model a passage of text and asks it to understand what the passage means. Most machine reading comprehension benchmarks are evaluated in question-answering style: natural-language questions about the passage are written, and the model must understand each question and answer it from the passage. Unlike text classification, reading comprehension often takes a *pair* of texts as input, one for the question and one for the context, and the answers come in several formats:\n",
  25. "\n",
  26. "- Multiple choice: the model picks the correct answer from several candidate options\n",
  27. "- Span extraction: the answer is a sub-span of the context, and the model must give the answer's position in it\n",
  28. "- Free-form answers: no restrictions, the model generates the answer itself\n",
  29. "- Cloze: key words are blanked out of the passage for the model to fill in; this format usually needs no explicit question\n",
  30. "\n",
  31. "If you are familiar with `transformers`, its `ModelForQuestionAnswering` family of models can be used for this task. How well a reading comprehension model generalizes is one of the key indicators of whether the technique can be deployed at scale in practice; with current progress, many models score well on certain test sets yet remain unsatisfactory in real applications. In this tutorial we will show you how to train a question-answering model.\n",
  32. "\n",
  33. "In this area, the `SQuAD` dataset has been highly influential. Its full name is the Stanford Question Answering Dataset; each example consists of a `(question, context, answer)` triple, and the dataset is large (about 100k examples, with another 50k added in version 2.0), so it quickly became one of the classic datasets for training question-answering models. `SQuAD` uses two metrics to measure model performance: `EM` (Exact Match) and `F1` (fuzzy match). The former reflects how many of the model's answers are identical to the gold answers, and the latter reflects how much the model's answers overlap with the gold answers; higher is better for both."
  34. ]
  35. },
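{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of the two metrics (a minimal sketch for intuition only, not the implementation inside `paddlenlp`'s `squad_evaluate`, which this tutorial uses later), `EM` and a character-level `F1` for a single prediction could be computed like this:\n",
"\n",
"```python\n",
"from collections import Counter\n",
"\n",
"def exact_match(pred, gold):\n",
"    # 1.0 only when the predicted answer equals the gold answer exactly\n",
"    return float(pred == gold)\n",
"\n",
"def char_f1(pred, gold):\n",
"    # character-level overlap between the prediction and the gold answer\n",
"    common = Counter(pred) & Counter(gold)\n",
"    num_same = sum(common.values())\n",
"    if num_same == 0:\n",
"        return 0.0\n",
"    precision = num_same / len(pred)\n",
"    recall = num_same / len(gold)\n",
"    return 2 * precision * recall / (precision + recall)\n",
"\n",
"print(exact_match(\"第35集\", \"第35集\"))  # 1.0\n",
"print(char_f1(\"第35集\", \"第三十五集\"))  # ~0.44\n",
"```"
]
},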
  36. {
  37. "cell_type": "markdown",
  38. "metadata": {},
  39. "source": [
  40. "### 2. Preparation: load the DuReader-robust dataset and preprocess it with the tokenizer"
  41. ]
  42. },
  43. {
  44. "cell_type": "code",
  45. "execution_count": 1,
  46. "metadata": {},
  47. "outputs": [
  48. {
  49. "name": "stderr",
  50. "output_type": "stream",
  51. "text": [
  52. "/remote-home/shxing/anaconda3/envs/fnlp-paddle/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
  53. " from .autonotebook import tqdm as notebook_tqdm\n"
  54. ]
  55. },
  56. {
  57. "name": "stdout",
  58. "output_type": "stream",
  59. "text": [
  60. "2.3.3\n"
  61. ]
  62. }
  63. ],
  64. "source": [
  65. "import sys\n",
  66. "sys.path.append(\"../\")\n",
  67. "import paddle\n",
  68. "import paddlenlp\n",
  69. "\n",
  70. "print(paddlenlp.__version__)"
  71. ]
  72. },
  73. {
  74. "cell_type": "markdown",
  75. "metadata": {},
  76. "source": [
  77. "For data we use the Chinese `DuReader-robust` dataset as our training data. It is an extractive question-answering dataset in the `SQuAD` format, designed to evaluate how well models generalize in real application scenarios."
  78. ]
  79. },
  80. {
  81. "cell_type": "code",
  82. "execution_count": 17,
  83. "metadata": {},
  84. "outputs": [
  85. {
  86. "name": "stderr",
  87. "output_type": "stream",
  88. "text": [
  89. "Reusing dataset dureader_robust (/remote-home/shxing/.cache/huggingface/datasets/dureader_robust/plain_text/1.0.0/d462ecadc8c010cee20f57632f1413f272867cd802a91a602df48c7d34eb0c27)\n",
  90. "Reusing dataset dureader_robust (/remote-home/shxing/.cache/huggingface/datasets/dureader_robust/plain_text/1.0.0/d462ecadc8c010cee20f57632f1413f272867cd802a91a602df48c7d34eb0c27)\n",
  91. "\u001b[32m[2022-06-27 19:22:46,998] [ INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/vocab.txt\u001b[0m\n"
  92. ]
  93. },
  94. {
  95. "name": "stdout",
  96. "output_type": "stream",
  97. "text": [
  98. "{'id': '0a25cb4bc1ab6f474c699884e04601e4', 'title': '', 'context': '第35集雪见缓缓张开眼睛,景天又惊又喜之际,长卿和紫萱的仙船驶至,见众人无恙,也十分高兴。众人登船,用尽合力把自身的真气和水分输给她。雪见终于醒过来了,但却一脸木然,全无反应。众人向常胤求助,却发现人世界竟没有雪见的身世纪录。长卿询问清微的身世,清微语带双关说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入。于是景天等人打下许多乌鸦,模仿重楼的翅膀,制作数对翅膀状巨物。刚佩戴在身,便被吸入洞口。众人摔落在地,抬头发现魔界守卫。景天和众魔套交情,自称和魔尊重楼相熟,众魔不理,打了起来。', 'question': '仙剑奇侠传3第几集上天界', 'answers': {'text': ['第35集'], 'answer_start': [0]}}\n",
  99. "{'id': '7de192d6adf7d60ba73ba25cf590cc1e', 'title': '', 'context': '选择燃气热水器时,一定要关注这几个问题:1、出水稳定性要好,不能出现忽热忽冷的现象2、快速到达设定的需求水温3、操作要智能、方便4、安全性要好,要装有安全报警装置 市场上燃气热水器品牌众多,购买时还需多加对比和仔细鉴别。方太今年主打的磁化恒温热水器在使用体验方面做了全面升级:9秒速热,可快速进入洗浴模式;水温持久稳定,不会出现忽热忽冷的现象,并通过水量伺服技术将出水温度精确控制在±0.5℃,可满足家里宝贝敏感肌肤洗护需求;配备CO和CH4双气体报警装置更安全(市场上一般多为CO单气体报警)。另外,这款热水器还有智能WIFI互联功能,只需下载个手机APP即可用手机远程操作热水器,实现精准调节水温,满足家人多样化的洗浴需求。当然方太的磁化恒温系列主要的是增加磁化功能,可以有效吸附水中的铁锈、铁屑等微小杂质,防止细菌滋生,使沐浴水质更洁净,长期使用磁化水沐浴更利于身体健康。', 'question': '燃气热水器哪个牌子好', 'answers': {'text': ['方太'], 'answer_start': [110]}}\n",
  100. "{'id': 'b9e74d4b9228399b03701d1fe6d52940', 'title': '', 'context': '迈克尔.乔丹在NBA打了15个赛季。他在84年进入nba,期间在1993年10月6日第一次退役改打棒球,95年3月18日重新回归,在99年1月13日第二次退役,后于2001年10月31日复出,在03年最终退役。迈克尔·乔丹(Michael Jordan),1963年2月17日生于纽约布鲁克林,美国著名篮球运动员,司职得分后卫,历史上最伟大的篮球运动员。1984年的NBA选秀大会,乔丹在首轮第3顺位被芝加哥公牛队选中。 1986-87赛季,乔丹场均得到37.1分,首次获得分王称号。1990-91赛季,乔丹连夺常规赛MVP和总决赛MVP称号,率领芝加哥公牛首次夺得NBA总冠军。 1997-98赛季,乔丹获得个人职业生涯第10个得分王,并率领公牛队第六次夺得总冠军。2009年9月11日,乔丹正式入选NBA名人堂。', 'question': '乔丹打了多少个赛季', 'answers': {'text': ['15个'], 'answer_start': [12]}}\n",
  101. "Training set size: 14520\n",
  102. "Validation set size: 1417\n"
  103. ]
  104. }
  105. ],
  106. "source": [
  107. "from paddlenlp.datasets import load_dataset\n",
  108. "train_dataset = load_dataset(\"PaddlePaddle/dureader_robust\", splits=\"train\")\n",
  109. "val_dataset = load_dataset(\"PaddlePaddle/dureader_robust\", splits=\"validation\")\n",
  110. "for i in range(3):\n",
  111. " print(train_dataset[i])\n",
  112. "print(\"Training set size:\", len(train_dataset))\n",
  113. "print(\"Validation set size:\", len(val_dataset))\n",
  114. "\n",
  115. "MODEL_NAME = \"ernie-1.0-base-zh\"\n",
  116. "from paddlenlp.transformers import ErnieTokenizer\n",
  117. "tokenizer = ErnieTokenizer.from_pretrained(MODEL_NAME)"
  118. ]
  119. },
  120. {
  121. "cell_type": "markdown",
  122. "metadata": {},
  123. "source": [
  124. "#### 2.1 Processing the training set\n",
  125. "\n",
  126. "For reading comprehension, preprocessing the data is somewhat involved. Below we explain the preprocessing function `_process_train` in detail, and along the way demonstrate more of what the `tokenizer` can do, to give you a deeper look at NLP preprocessing. Let us start by feeding one example (wrapped in a list) to the `tokenizer`:"
  127. ]
  128. },
  129. {
  130. "cell_type": "code",
  131. "execution_count": 3,
  132. "metadata": {},
  133. "outputs": [
  134. {
  135. "name": "stdout",
  136. "output_type": "stream",
  137. "text": [
  138. "2\n",
  139. "dict_keys(['offset_mapping', 'input_ids', 'token_type_ids', 'overflow_to_sample'])\n"
  140. ]
  141. }
  142. ],
  143. "source": [
  144. "result = tokenizer(\n",
  145. " [train_dataset[0][\"question\"]],\n",
  146. " [train_dataset[0][\"context\"]],\n",
  147. " stride=128,\n",
  148. " max_length=256,\n",
  149. " padding=\"max_length\",\n",
  150. " return_dict=False\n",
  151. ")\n",
  152. "\n",
  153. "print(len(result))\n",
  154. "print(result[0].keys())"
  155. ]
  156. },
  157. {
  158. "cell_type": "markdown",
  159. "metadata": {},
  160. "source": [
  161. "First, it is easy to see that the model must receive both the question (`question`) and the context (`context`) to do reading comprehension, so we need to tokenize the two together. Fortunately the `Tokenizer` supports this: when we call the `tokenizer`, its first argument is named `text` and its second is named `text_pair`, which lets us tokenize a pair of texts at once. The `tokenizer` also needs to mark which part of an example is the question and which part is the context, and that is what `token_type_ids` does: tokens from the first text (the question) are labeled `0` and tokens from the second text (the context) are labeled `1`, so the model can tell the two apart during training:"
  162. ]
  163. },
  164. {
  165. "cell_type": "code",
  166. "execution_count": 4,
  167. "metadata": {},
  168. "outputs": [
  169. {
  170. "name": "stdout",
  171. "output_type": "stream",
  172. "text": [
  173. "[1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427, 1427, 501, 88, 662, 1906, 4, 561, 125, 311, 1168, 311, 692, 46, 430, 4, 84, 2073, 14, 1264, 3967, 5, 1034, 1020, 1829, 268, 4, 373, 539, 8, 154, 5210, 4, 105, 167, 59, 69, 685, 12043, 539, 8, 883, 1020, 4, 29, 720, 95, 90, 427, 67, 262, 5, 384, 266, 14, 101, 59, 789, 416, 237, 12043, 1097, 373, 616, 37, 1519, 93, 61, 15, 4, 255, 535, 7, 1529, 619, 187, 4, 62, 154, 451, 149, 12043, 539, 8, 253, 223, 3679, 323, 523, 4, 535, 34, 87, 8, 203, 280, 1186, 340, 9, 1097, 373, 5, 262, 203, 623, 704, 12043, 84, 2073, 1137, 358, 334, 702, 5, 262, 203, 4, 334, 702, 405, 360, 653, 129, 178, 7, 568, 28, 15, 125, 280, 518, 9, 1179, 487, 12043, 84, 2073, 1621, 1829, 1034, 1020, 4, 539, 8, 448, 91, 202, 466, 70, 262, 4, 638, 125, 280, 83, 299, 12043, 539, 8, 61, 45, 7, 1537, 176, 4, 84, 2073, 288, 39, 4, 889, 280, 14, 125, 280, 156, 538, 12043, 190, 889, 280, 71, 109, 124, 93, 292, 889, 46, 1248, 4, 518, 48, 883, 125, 12043, 539, 8, 268, 889, 280, 109, 270, 4, 1586, 845, 7, 669, 199, 5, 3964, 3740, 1084, 4, 255, 440, 616, 154, 72, 71, 109, 12043, 49, 61, 283, 3591, 34, 87, 297, 41, 9, 1993, 2602, 518, 52, 706, 109, 2]\n",
  174. "['[CLS]', '仙', '剑', '奇', '侠', '传', '3', '第', '几', '集', '上', '天', '界', '[SEP]', '第', '35', '集', '雪', '见', '缓', '缓', '张', '开', '眼', '睛', ',', '景', '天', '又', '惊', '又', '喜', '之', '际', ',', '长', '卿', '和', '紫', '萱', '的', '仙', '船', '驶', '至', ',', '见', '众', '人', '无', '恙', ',', '也', '十', '分', '高', '兴', '。', '众', '人', '登', '船', ',', '用', '尽', '合', '力', '把', '自', '身', '的', '真', '气', '和', '水', '分', '输', '给', '她', '。', '雪', '见', '终', '于', '醒', '过', '来', '了', ',', '但', '却', '一', '脸', '木', '然', ',', '全', '无', '反', '应', '。', '众', '人', '向', '常', '胤', '求', '助', ',', '却', '发', '现', '人', '世', '界', '竟', '没', '有', '雪', '见', '的', '身', '世', '纪', '录', '。', '长', '卿', '询', '问', '清', '微', '的', '身', '世', ',', '清', '微', '语', '带', '双', '关', '说', '一', '切', '上', '了', '天', '界', '便', '有', '答', '案', '。', '长', '卿', '驾', '驶', '仙', '船', ',', '众', '人', '决', '定', '立', '马', '动', '身', ',', '往', '天', '界', '而', '去', '。', '众', '人', '来', '到', '一', '荒', '山', ',', '长', '卿', '指', '出', ',', '魔', '界', '和', '天', '界', '相', '连', '。', '由', '魔', '界', '进', '入', '通', '过', '神', '魔', '之', '井', ',', '便', '可', '登', '天', '。', '众', '人', '至', '魔', '界', '入', '口', ',', '仿', '若', '一', '黑', '色', '的', '蝙', '蝠', '洞', ',', '但', '始', '终', '无', '法', '进', '入', '。', '后', '来', '花', '楹', '发', '现', '只', '要', '有', '翅', '膀', '便', '能', '飞', '入', '[SEP]']\n",
  175. "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n"
  176. ]
  177. }
  178. ],
  179. "source": [
  180. "print(result[0][\"input_ids\"])\n",
  181. "print(tokenizer.convert_ids_to_tokens(result[0][\"input_ids\"]))\n",
  182. "print(result[0][\"token_type_ids\"])"
  183. ]
  184. },
  185. {
  186. "cell_type": "markdown",
  187. "metadata": {},
  188. "source": [
  189. "From the output above we can see that the `tokenizer` marks the beginning of the data with `[CLS]` and separates the sentences with `[SEP]`. Using the string of 0s and 1s in `token_type_ids`, it is also easy to tell the question apart from the context. Incidentally, if an example is padded, the padded part is labeled `0` as well.\n",
  190. "\n",
  191. "Among the output `keys` there is also an entry called `offset_mapping`. It records, for each `token` produced by tokenization, the position of the corresponding character or word in the original text. For example, we can print it out like this:"
  192. ]
  193. },
  194. {
  195. "cell_type": "code",
  196. "execution_count": 5,
  197. "metadata": {},
  198. "outputs": [
  199. {
  200. "name": "stdout",
  201. "output_type": "stream",
  202. "text": [
  203. "[(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (0, 0), (0, 1), (1, 3), (3, 4), (4, 5), (5, 6), (6, 7)]\n",
  204. "[1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427]\n",
  205. "['[CLS]', '仙', '剑', '奇', '侠', '传', '3', '第', '几', '集', '上', '天', '界', '[SEP]', '第', '35', '集', '雪', '见', '缓']\n"
  206. ]
  207. }
  208. ],
  209. "source": [
  210. "print(result[0][\"offset_mapping\"][:20])\n",
  211. "print(result[0][\"input_ids\"][:20])\n",
  212. "print(tokenizer.convert_ids_to_tokens(result[0][\"input_ids\"])[:20])"
  213. ]
  214. },
  215. {
  216. "cell_type": "markdown",
  217. "metadata": {},
  218. "source": [
  219. "Because `[CLS]` is a `token` that the `tokenizer` adds by itself to mark the data, it corresponds to no word in the original text, so its position range is `(0, 0)`; the second `token` corresponds to the first character `“仙”`, so its mapping is `(0, 1)`; likewise, the later `[SEP]` corresponds to no text and maps to `(0, 0)`; the `token` after it corresponds to the first character of the **context**, `“第”`, so it maps to `(0, 1)`; and the next `token` corresponds to the two characters `35` in the original text, so its mapping is `(1, 3)`. In this way we can conveniently recover how each `token` corresponds to the original text.\n",
  220. "\n",
  221. "Finally, you may have noticed that the `result` we obtained has length 2. This is because the tokenized text exceeds the `max_length` of 256, so the `tokenizer` split the example into two parts. In reading comprehension we cannot simply truncate an example the way we would in text classification, because the answer may well lie in the discarded part, so we keep all of the data (although you could also just drop these over-long examples). `overflow_to_sample` indicates the index of the original example each piece came from:"
  222. ]
  223. },
  224. {
  225. "cell_type": "code",
  226. "execution_count": 6,
  227. "metadata": {},
  228. "outputs": [
  229. {
  230. "name": "stdout",
  231. "output_type": "stream",
  232. "text": [
  233. "[CLS]仙剑奇侠传3第几集上天界[SEP]第35集雪见缓缓张开眼睛,景天又惊又喜之际,长卿和紫萱的仙船驶至,见众人无恙,也十分高兴。众人登船,用尽合力把自身的真气和水分输给她。雪见终于醒过来了,但却一脸木然,全无反应。众人向常胤求助,却发现人世界竟没有雪见的身世纪录。长卿询问清微的身世,清微语带双关说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入[SEP]\n",
  234. "overflow_to_sample: 0\n",
  235. "[CLS]仙剑奇侠传3第几集上天界[SEP]说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入。于是景天等人打下许多乌鸦,模仿重楼的翅膀,制作数对翅膀状巨物。刚佩戴在身,便被吸入洞口。众人摔落在地,抬头发现魔界守卫。景天和众魔套交情,自称和魔尊重楼相熟,众魔不理,打了起来。[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]\n",
  236. "overflow_to_sample: 0\n"
  237. ]
  238. }
  239. ],
  240. "source": [
  241. "for res in result:\n",
  242. " tokens = tokenizer.convert_ids_to_tokens(res[\"input_ids\"])\n",
  243. " print(\"\".join(tokens))\n",
  244. " print(\"overflow_to_sample: \", res[\"overflow_to_sample\"])"
  245. ]
  246. },
  247. {
  248. "cell_type": "markdown",
  249. "metadata": {},
  250. "source": [
  251. "Printing both pieces shows that they both come from the example we passed in and that they partially overlap. The `stride` parameter of the `tokenizer` sets the length of this overlap, which also helps the model relate the two pieces that were split apart; the `0` in `overflow_to_sample` means both pieces come from example `0`.\n",
  252. "\n",
  253. "With this information, our plan for processing the training set is:\n",
  254. "\n",
  255. "1. Use `overflow_to_sample` to recover the original example\n",
  256. "2. Use the original example's `answers` to find where the answer starts\n",
  257. "3. Use the mapping given by `offset_mapping` to find the answer's start and end positions in the tokenized data, recording them in `start_pos` and `end_pos`; if the answer cannot be found (for example, because it was truncated away), mark both positions as the position of `[CLS]`.\n",
  258. "\n",
  259. "The `_process_train` function then follows naturally. We call `train_dataset.map` with `batched=True` so that all examples are updated in batches. One thing to note is that **the number of examples increases after processing**."
  260. ]
  261. },
  262. {
  263. "cell_type": "code",
  264. "execution_count": 18,
  265. "metadata": {},
  266. "outputs": [
  267. {
  268. "name": "stdout",
  269. "output_type": "stream",
  270. "text": [
  271. "{'offset_mapping': [(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (0, 0), (0, 1), (1, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (19, 20), (20, 21), (21, 22), (22, 23), (23, 24), (24, 25), (25, 26), (26, 27), (27, 28), (28, 29), (29, 30), (30, 31), (31, 32), (32, 33), (33, 34), (34, 35), (35, 36), (36, 37), (37, 38), (38, 39), (39, 40), (40, 41), (41, 42), (42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 48), (48, 49), (49, 50), (50, 51), (51, 52), (52, 53), (53, 54), (54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 60), (60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 68), (68, 69), (69, 70), (70, 71), (71, 72), (72, 73), (73, 74), (74, 75), (75, 76), (76, 77), (77, 78), (78, 79), (79, 80), (80, 81), (81, 82), (82, 83), (83, 84), (84, 85), (85, 86), (86, 87), (87, 88), (88, 89), (89, 90), (90, 91), (91, 92), (92, 93), (93, 94), (94, 95), (95, 96), (96, 97), (97, 98), (98, 99), (99, 100), (100, 101), (101, 102), (102, 103), (103, 104), (104, 105), (105, 106), (106, 107), (107, 108), (108, 109), (109, 110), (110, 111), (111, 112), (112, 113), (113, 114), (114, 115), (115, 116), (116, 117), (117, 118), (118, 119), (119, 120), (120, 121), (121, 122), (122, 123), (123, 124), (124, 125), (125, 126), (126, 127), (127, 128), (128, 129), (129, 130), (130, 131), (131, 132), (132, 133), (133, 134), (134, 135), (135, 136), (136, 137), (137, 138), (138, 139), (139, 140), (140, 141), (141, 142), (142, 143), (143, 144), (144, 145), (145, 146), (146, 147), (147, 148), (148, 149), (149, 150), (150, 151), (151, 152), (152, 153), (153, 154), (154, 155), (155, 156), (156, 157), (157, 158), (158, 159), (159, 160), (160, 161), (161, 162), (162, 163), (163, 164), (164, 165), (165, 166), (166, 167), (167, 168), (168, 169), (169, 170), (170, 171), (171, 172), (172, 173), (173, 174), (174, 175), (175, 176), (176, 177), (177, 178), (178, 179), (179, 180), (180, 181), (181, 182), (182, 183), (183, 184), (184, 185), (185, 186), (186, 187), (187, 188), (188, 189), (189, 190), (190, 191), (191, 192), (192, 193), (193, 194), (194, 195), (195, 196), (196, 197), (197, 198), (198, 199), (199, 200), (200, 201), (201, 202), (202, 203), (203, 204), (204, 205), (205, 206), (206, 207), (207, 208), (208, 209), (209, 210), (210, 211), (211, 212), (212, 213), (213, 214), (214, 215), (215, 216), (216, 217), (217, 218), (218, 219), (219, 220), (220, 221), (221, 222), (222, 223), (223, 224), (224, 225), (225, 226), (226, 227), (227, 228), (228, 229), (229, 230), (230, 231), (231, 232), (232, 233), (233, 234), (234, 235), (235, 236), (236, 237), (237, 238), (238, 239), (239, 240), (240, 241), (241, 242), (0, 0)], 'input_ids': [1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427, 1427, 501, 88, 662, 1906, 4, 561, 125, 311, 1168, 311, 692, 46, 430, 4, 84, 2073, 14, 1264, 3967, 5, 1034, 1020, 1829, 268, 4, 373, 539, 8, 154, 5210, 4, 105, 167, 59, 69, 685, 12043, 539, 8, 883, 1020, 4, 29, 720, 95, 90, 427, 67, 262, 5, 384, 266, 14, 101, 59, 789, 416, 237, 12043, 1097, 373, 616, 37, 1519, 93, 61, 15, 4, 255, 535, 7, 1529, 619, 187, 4, 62, 154, 451, 149, 12043, 539, 8, 253, 223, 3679, 323, 523, 4, 535, 34, 87, 8, 203, 280, 1186, 340, 9, 1097, 373, 5, 262, 203, 623, 704, 12043, 84, 2073, 1137, 358, 334, 702, 5, 262, 203, 4, 334, 702, 405, 360, 653, 129, 178, 7, 568, 28, 15, 125, 
280, 518, 9, 1179, 487, 12043, 84, 2073, 1621, 1829, 1034, 1020, 4, 539, 8, 448, 91, 202, 466, 70, 262, 4, 638, 125, 280, 83, 299, 12043, 539, 8, 61, 45, 7, 1537, 176, 4, 84, 2073, 288, 39, 4, 889, 280, 14, 125, 280, 156, 538, 12043, 190, 889, 280, 71, 109, 124, 93, 292, 889, 46, 1248, 4, 518, 48, 883, 125, 12043, 539, 8, 268, 889, 280, 109, 270, 4, 1586, 845, 7, 669, 199, 5, 3964, 3740, 1084, 4, 255, 440, 616, 154, 72, 71, 109, 12043, 49, 61, 283, 3591, 34, 87, 297, 41, 9, 1993, 2602, 518, 52, 706, 109, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'overflow_to_sample': 0, 'start_pos': 14, 'end_pos': 16}\n",
  272. "Training set size after processing: 26198\n"
  273. ]
  274. }
  275. ],
  276. "source": [
  277. "max_length = 256\n",
  278. "doc_stride = 128\n",
  279. "def _process_train(data):\n",
  280. "\n",
  281. " contexts = [data[i][\"context\"] for i in range(len(data))]\n",
  282. " questions = [data[i][\"question\"] for i in range(len(data))]\n",
  283. "\n",
  284. " tokenized_data_list = tokenizer(\n",
  285. " questions,\n",
  286. " contexts,\n",
  287. " stride=doc_stride,\n",
  288. " max_length=max_length,\n",
  289. " padding=\"max_length\",\n",
  290. " return_dict=False\n",
  291. " )\n",
  292. "\n",
  293. " for i, tokenized_data in enumerate(tokenized_data_list):\n",
  294. " # get the position of the [CLS] token\n",
  295. " input_ids = tokenized_data[\"input_ids\"]\n",
  296. " cls_index = input_ids.index(tokenizer.cls_token_id)\n",
  297. "\n",
  298. " # during tokenization, Chinese characters and tokens do not line up one-to-one by position\n",
  299. " # the offset mapping records each token's start and end positions in the original text\n",
  300. " offsets = tokenized_data[\"offset_mapping\"]\n",
  301. " # token_type_ids records which part of an example is the question and which is the context\n",
  302. " token_type_ids = tokenized_data[\"token_type_ids\"]\n",
  303. "\n",
  304. " # one example may yield several entries in tokenized_data because it is too long\n",
  305. " # overflow_to_sample tells us which example in data the current tokenized piece belongs to\n",
  306. " sample_index = tokenized_data[\"overflow_to_sample\"]\n",
  307. " answers = data[sample_index][\"answers\"]\n",
  308. "\n",
  309. " # answers and answer_starts are both lists of length 1\n",
  310. " # from them we can compute where the answer ends\n",
  311. " start_char = answers[\"answer_start\"][0]\n",
  312. " end_char = start_char + len(answers[\"text\"][0])\n",
  313. "\n",
  314. " token_start_index = 0\n",
  315. " while token_type_ids[token_start_index] != 1:\n",
  316. " token_start_index += 1\n",
  317. "\n",
  318. " token_end_index = len(input_ids) - 1\n",
  319. " while token_type_ids[token_end_index] != 1:\n",
  320. " token_end_index -= 1\n",
  321. " # after tokenization an example always ends with [SEP], so subtract one more\n",
  322. " token_end_index -= 1\n",
  323. "\n",
  324. " if not (offsets[token_start_index][0] <= start_char and\n",
  325. " offsets[token_end_index][1] >= end_char):\n",
  326. " # if the answer is not in this piece, mark the answer position as the position of [CLS]\n",
  327. " tokenized_data_list[i][\"start_pos\"] = cls_index\n",
  328. " tokenized_data_list[i][\"end_pos\"] = cls_index\n",
  329. " else:\n",
  330. " # otherwise, locate the answer's start and end tokens and record them in start_pos and end_pos\n",
  331. " while token_start_index < len(offsets) and offsets[\n",
  332. " token_start_index][0] <= start_char:\n",
  333. " token_start_index += 1\n",
  334. " tokenized_data_list[i][\"start_pos\"] = token_start_index - 1\n",
  335. " while offsets[token_end_index][1] >= end_char:\n",
  336. " token_end_index -= 1\n",
  337. " tokenized_data_list[i][\"end_pos\"] = token_end_index + 1\n",
  338. "\n",
  339. " return tokenized_data_list\n",
  340. "\n",
  341. "train_dataset.map(_process_train, batched=True, num_workers=5)\n",
  342. "print(train_dataset[0])\n",
  343. "print(\"Training set size after processing:\", len(train_dataset))"
  344. ]
  345. },
  346. {
  347. "cell_type": "markdown",
  348. "metadata": {},
  349. "source": [
  350. "#### 2.2 Processing the validation set\n",
  351. "\n",
  352. "Processing the validation set is much simpler: we only need to keep each original example's `id` and set the entries of `offset_mapping` that do not belong to the context to `None`."
  353. ]
  354. },
  355. {
  356. "cell_type": "code",
  357. "execution_count": 8,
  358. "metadata": {},
  359. "outputs": [
  360. {
  361. "data": {
  362. "text/plain": [
  363. "<paddlenlp.datasets.dataset.MapDataset at 0x7f697503d7d0>"
  364. ]
  365. },
  366. "execution_count": 8,
  367. "metadata": {},
  368. "output_type": "execute_result"
  369. }
  370. ],
  371. "source": [
  372. "def _process_val(data):\n",
  373. "\n",
  374. " contexts = [data[i][\"context\"] for i in range(len(data))]\n",
  375. " questions = [data[i][\"question\"] for i in range(len(data))]\n",
  376. "\n",
  377. " tokenized_data_list = tokenizer(\n",
  378. " questions,\n",
  379. " contexts,\n",
  380. " stride=doc_stride,\n",
  381. " max_length=max_length,\n",
  382. " return_dict=False\n",
  383. " )\n",
  384. "\n",
  385. " for i, tokenized_data in enumerate(tokenized_data_list):\n",
  386. " token_type_ids = tokenized_data[\"token_type_ids\"]\n",
  387. " # keep the id of the original example\n",
  388. " sample_index = tokenized_data[\"overflow_to_sample\"]\n",
  389. " tokenized_data_list[i][\"example_id\"] = data[sample_index][\"id\"]\n",
  390. "\n",
  391. " # set offsets that do not belong to the context to None\n",
  392. " tokenized_data_list[i][\"offset_mapping\"] = [\n",
  393. " (o if token_type_ids[k] == 1 else None)\n",
  394. " for k, o in enumerate(tokenized_data[\"offset_mapping\"])\n",
  395. " ]\n",
  396. "\n",
  397. " return tokenized_data_list\n",
  398. "\n",
  399. "val_dataset.map(_process_val, batched=True, num_workers=5)"
  400. ]
  401. },
  402. {
  403. "cell_type": "markdown",
  404. "metadata": {},
  405. "source": [
  406. "#### 2.3 DataLoader\n",
  407. "\n",
  408. "Finally, we simply wrap the datasets with `PaddleDataLoader`."
  409. ]
  410. },
  411. {
  412. "cell_type": "code",
  413. "execution_count": 9,
  414. "metadata": {},
  415. "outputs": [
  416. {
  417. "data": {
  418. "text/html": [
  419. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  420. "</pre>\n"
  421. ],
  422. "text/plain": [
  423. "\n"
  424. ]
  425. },
  426. "metadata": {},
  427. "output_type": "display_data"
  428. }
  429. ],
  430. "source": [
  431. "from fastNLP.core import PaddleDataLoader\n",
  432. "\n",
  433. "train_dataloader = PaddleDataLoader(train_dataset, batch_size=32, shuffle=True)\n",
  434. "val_dataloader = PaddleDataLoader(val_dataset, batch_size=16)"
  435. ]
  436. },
  437. {
  438. "cell_type": "markdown",
  439. "metadata": {},
  440. "source": [
  441. "### 3. Training: define our own evaluation Metric for more flexible task evaluation\n",
  442. "\n",
  443. "#### 3.1 Loss function\n",
  444. "\n",
  445. "For reading comprehension we use the `ErnieForQuestionAnswering` model. Given an input it returns two values, `start_logits` and `end_logits`, both of shape `(batch_size, sequence_length)`, which reflect how likely each token of each example is to be the start or end of the answer, so we need a custom loss function to compute the `loss`. `CrossEntropyLossForSquad` computes the cross entropy between the predicted and true start positions and between the predicted and true end positions, and returns their average as the final loss."
  446. ]
  447. },
  448. {
  449. "cell_type": "code",
  450. "execution_count": 10,
  451. "metadata": {},
  452. "outputs": [],
  453. "source": [
  454. "class CrossEntropyLossForSquad(paddle.nn.Layer):\n",
  455. " def __init__(self):\n",
  456. " super(CrossEntropyLossForSquad, self).__init__()\n",
  457. "\n",
  458. " def forward(self, start_logits, end_logits, start_pos, end_pos):\n",
  459. " start_pos = paddle.unsqueeze(start_pos, axis=-1)\n",
  460. " end_pos = paddle.unsqueeze(end_pos, axis=-1)\n",
  461. " start_loss = paddle.nn.functional.softmax_with_cross_entropy(\n",
  462. " logits=start_logits, label=start_pos)\n",
  463. " start_loss = paddle.mean(start_loss)\n",
  464. " end_loss = paddle.nn.functional.softmax_with_cross_entropy(\n",
  465. " logits=end_logits, label=end_pos)\n",
  466. " end_loss = paddle.mean(end_loss)\n",
  467. "\n",
  468. " loss = (start_loss + end_loss) / 2\n",
  469. " return loss"
  470. ]
  471. },
  472. {
  473. "cell_type": "markdown",
  474. "metadata": {},
  475. "source": [
  476. "#### 3.2 Defining the model\n",
  477. "\n",
  478. "The core of the model is the `ernie-1.0-base-zh` pretrained checkpoint loaded into `ErnieForQuestionAnswering`, around which we define the `train_step` and `evaluate_step` functions that `FastNLP` expects. Note that `evaluate_step` does not directly return the batch's evaluation result the way it would in text classification; we explain why below."
  479. ]
  480. },
  481. {
  482. "cell_type": "code",
  483. "execution_count": 11,
  484. "metadata": {},
  485. "outputs": [
  486. {
  487. "name": "stderr",
  488. "output_type": "stream",
  489. "text": [
  490. "\u001b[32m[2022-06-27 19:00:15,825] [ INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/ernie_v1_chn_base.pdparams\u001b[0m\n",
  491. "W0627 19:00:15.831080 21543 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.5, Driver API Version: 11.2, Runtime API Version: 11.2\n",
  492. "W0627 19:00:15.843276 21543 gpu_context.cc:306] device: 0, cuDNN Version: 8.1.\n"
  493. ]
  494. }
  495. ],
  496. "source": [
  497. "from paddlenlp.transformers import ErnieForQuestionAnswering\n",
  498. "\n",
  499. "class QAModel(paddle.nn.Layer):\n",
  500. " def __init__(self, model_checkpoint):\n",
  501. " super(QAModel, self).__init__()\n",
  502. " self.model = ErnieForQuestionAnswering.from_pretrained(model_checkpoint)\n",
  503. " self.loss_func = CrossEntropyLossForSquad()\n",
  504. "\n",
  505. " def forward(self, input_ids, token_type_ids):\n",
  506. " start_logits, end_logits = self.model(input_ids, token_type_ids)\n",
  507. " return start_logits, end_logits\n",
  508. "\n",
  509. " def train_step(self, input_ids, token_type_ids, start_pos, end_pos):\n",
  510. " start_logits, end_logits = self(input_ids, token_type_ids)\n",
  511. " loss = self.loss_func(start_logits, end_logits, start_pos, end_pos)\n",
  512. " return {\"loss\": loss}\n",
  513. "\n",
  514. " def evaluate_step(self, input_ids, token_type_ids):\n",
  515. " start_logits, end_logits = self(input_ids, token_type_ids)\n",
  516. " return {\"start_logits\": start_logits, \"end_logits\": end_logits}\n",
  517. "\n",
  518. "model = QAModel(MODEL_NAME)"
  519. ]
  520. },
  521. {
  522. "cell_type": "markdown",
  523. "metadata": {},
  524. "source": [
  525. "#### 3.3 Evaluating with a custom Metric\n",
  526. "\n",
  527. "`paddlenlp` provides two functions for evaluating `SQuAD`-format datasets, `compute_prediction` and `squad_evaluate`:\n",
  528. "- `compute_prediction` takes the original data `examples`, the processed data `features`, and the `predictions` for those `features` (a tuple holding the `start_logits` and `end_logits` of all examples)\n",
  529. "- `squad_evaluate` takes the original data `examples` and the predictions `all_predictions` (usually the output of `compute_prediction`)\n",
  530. "\n",
  531. "Both functions need the datasets passed in, which, given the design of `fastNLP`, we clearly cannot do inside `evaluate_step`; moreover, `FastNLP` does not ship a `Metric` that computes `F1` and `EM`, so we have to define our own evaluation `Metric`.\n",
  532. "\n",
  533. "Besides its initializer, a `Metric` must implement three functions:\n",
  534. "\n",
  535. "1. `reset` - called before iterating over the validation set, to clear state; in our custom `Metric` we clear `all_start_logits` and `all_end_logits` so that each evaluation collects the per-`batch` results afresh.\n",
  536. "2. `update` - called after each `batch` produces its results, to update the `Metric`'s state; its arguments are exactly what `evaluate_step` returns. Here we collect the `start_logits` and `end_logits`.\n",
  537. "3. `get_metric` - called after the whole dataset has been iterated, to compute the final result. At that point we have `all_start_logits` and `all_end_logits` for the entire validation set; we pass them to `compute_prediction` to obtain the predictions, and then to `squad_evaluate` to obtain the evaluation result.\n",
  538. " - Note: `squad_evaluate` prints its evaluation results itself; to keep this from interfering with `FastNLP`'s output, we suppress the function's standard output with `contextlib.redirect_stdout(None)`.\n",
  539. "\n",
  540. "In short, the evaluation procedure implemented by `SquadEvaluateMetric` is: collect the `logits` of every example in the validation set, then pass them to `compute_prediction` and `squad_evaluate` in one go. It is worth mentioning that `paddlenlp.datasets.load_dataset` returns a `MapDataset`, whose `data` member holds the data as loaded and whose `new_data` member holds the data updated by the `map` function, so the two can be passed in as `examples` and `features` respectively."
  541. ]
  542. },
  543. {
  544. "cell_type": "code",
  545. "execution_count": 14,
  546. "metadata": {},
  547. "outputs": [],
  548. "source": [
  549. "from fastNLP.core import Metric\n",
  550. "from paddlenlp.metrics.squad import squad_evaluate, compute_prediction\n",
  551. "import contextlib\n",
  552. "\n",
  553. "class SquadEvaluateMetric(Metric):\n",
  554. " def __init__(self, examples, features, testing=False):\n",
  555. " super(SquadEvaluateMetric, self).__init__(\"paddle\", False)\n",
  556. " self.examples = examples\n",
  557. " self.features = features\n",
  558. " self.all_start_logits = []\n",
  559. " self.all_end_logits = []\n",
  560. " self.testing = testing\n",
  561. "\n",
  562. " def reset(self):\n",
  563. " self.all_start_logits = []\n",
  564. " self.all_end_logits = []\n",
  565. "\n",
  566. " def update(self, start_logits, end_logits):\n",
  567. " for start, end in zip(start_logits, end_logits):\n",
  568. " self.all_start_logits.append(start.numpy())\n",
  569. " self.all_end_logits.append(end.numpy())\n",
  570. "\n",
  571. " def get_metric(self):\n",
  572. " all_predictions, _, _ = compute_prediction(\n",
  573. " self.examples, self.features[:len(self.all_start_logits)],\n",
  574. " (self.all_start_logits, self.all_end_logits),\n",
  575. " False, 20, 30\n",
  576. " )\n",
  577. " with contextlib.redirect_stdout(None):\n",
  578. " result = squad_evaluate(\n",
  579. " examples=self.examples,\n",
  580. " preds=all_predictions,\n",
  581. " is_whitespace_splited=False\n",
  582. " )\n",
  583. "\n",
  584. " if self.testing:\n",
  585. " self.print_predictions(all_predictions)\n",
  586. " return result\n",
  587. "\n",
  588. " def print_predictions(self, preds):\n",
  589. " for i, data in enumerate(self.examples):\n",
  590. " if i >= 5:\n",
  591. " break\n",
  592. " print()\n",
  593. " print(\"Context:\", data[\"context\"])\n",
  594. " print(\"Question:\", data[\"question\"], \\\n",
  595. " \"Predicted answer:\", preds[data[\"id\"]], \\\n",
  596. " \"Gold answer:\", data[\"answers\"][\"text\"])\n",
  597. "\n",
  598. "metric = SquadEvaluateMetric(\n",
  599. " val_dataloader.dataset.data,\n",
  600. " val_dataloader.dataset.new_data,\n",
  601. ")"
  602. ]
  603. },
  604. {
  605. "cell_type": "markdown",
  606. "metadata": {},
  607. "source": [
  608. "#### 3.4 Training\n",
  609. "\n",
  610. "All the preparation is now done, and we can train with the `Trainer`. For the learning rate we again use the linear warmup schedule `LinearDecayWithWarmup`, with `AdamW` as the optimizer; for callbacks we choose `LRSchedCallback` to update the learning rate and `LoadBestModelCallback` to monitor the `f1` score of the evaluation results. Once the `Trainer` is initialized, just hand the training process over to `FastNLP`."
  611. ]
  612. },
  613. {
  614. "cell_type": "code",
  615. "execution_count": 15,
  616. "metadata": {},
  617. "outputs": [
  618. {
  619. "data": {
  620. "text/html": [
  621. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[19:04:54] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches. <a href=\"file://../fastNLP/core/controllers/trainer.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/controllers/trainer.py#631\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">631</span></a>\n",
  622. "</pre>\n"
  623. ],
  624. "text/plain": [
  625. "\u001b[2;36m[19:04:54]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=367046;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=96810;file://../fastNLP/core/controllers/trainer.py#631\u001b\\\u001b[2m631\u001b[0m\u001b]8;;\u001b\\\n"
  626. ]
  627. },
  628. "metadata": {},
  629. "output_type": "display_data"
  630. },
  631. {
  632. "data": {
  633. "text/html": [
  634. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  635. ],
  636. "text/plain": []
  637. },
  638. "metadata": {},
  639. "output_type": "display_data"
  640. },
  641. {
  642. "data": {
  643. "text/html": [
  644. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  645. "</pre>\n"
  646. ],
  647. "text/plain": [
  648. "\n"
  649. ]
  650. },
  651. "metadata": {},
  652. "output_type": "display_data"
  653. },
  654. {
  655. "data": {
  656. "text/html": [
  657. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100</span> ----------------------------\n",
  658. "</pre>\n"
  659. ],
  660. "text/plain": [
  661. "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m100\u001b[0m ----------------------------\n"
  662. ]
  663. },
  664. "metadata": {},
  665. "output_type": "display_data"
  666. },
  667. {
  668. "data": {
  669. "text/html": [
  670. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  671. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"exact#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">49.25899788285109</span>,\n",
  672. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f1#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">66.55559127349602</span>,\n",
  673. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1417</span>,\n",
  674. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_exact#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">49.25899788285109</span>,\n",
  675. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_f1#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">66.55559127349602</span>,\n",
  676. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_total#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1417</span>\n",
  677. "<span style=\"font-weight: bold\">}</span>\n",
  678. "</pre>\n"
  679. ],
  680. "text/plain": [
  681. "\u001b[1m{\u001b[0m\n",
  682. " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m49.25899788285109\u001b[0m,\n",
  683. " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m66.55559127349602\u001b[0m,\n",
  684. " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n",
  685. " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m49.25899788285109\u001b[0m,\n",
  686. " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m66.55559127349602\u001b[0m,\n",
  687. " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n",
  688. "\u001b[1m}\u001b[0m\n"
  689. ]
  690. },
  691. "metadata": {},
  692. "output_type": "display_data"
  693. },
  694. {
  695. "data": {
  696. "text/html": [
  697. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  698. "</pre>\n"
  699. ],
  700. "text/plain": [
  701. "\n"
  702. ]
  703. },
  704. "metadata": {},
  705. "output_type": "display_data"
  706. },
  707. {
  708. "data": {
  709. "text/html": [
  710. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">200</span> ----------------------------\n",
  711. "</pre>\n"
  712. ],
  713. "text/plain": [
  714. "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m200\u001b[0m ----------------------------\n"
  715. ]
  716. },
  717. "metadata": {},
  718. "output_type": "display_data"
  719. },
  720. {
  721. "data": {
  722. "text/html": [
  723. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  724. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"exact#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">57.37473535638673</span>,\n",
  725. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f1#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">70.93036525200617</span>,\n",
  726. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1417</span>,\n",
  727. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_exact#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">57.37473535638673</span>,\n",
  728. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_f1#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">70.93036525200617</span>,\n",
  729. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_total#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1417</span>\n",
  730. "<span style=\"font-weight: bold\">}</span>\n",
  731. "</pre>\n"
  732. ],
  733. "text/plain": [
  734. "\u001b[1m{\u001b[0m\n",
  735. " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m57.37473535638673\u001b[0m,\n",
  736. " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m70.93036525200617\u001b[0m,\n",
  737. " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n",
  738. " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m57.37473535638673\u001b[0m,\n",
  739. " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m70.93036525200617\u001b[0m,\n",
  740. " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n",
  741. "\u001b[1m}\u001b[0m\n"
  742. ]
  743. },
  744. "metadata": {},
  745. "output_type": "display_data"
  746. },
  747. {
  748. "data": {
  749. "text/html": [
  750. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  751. "</pre>\n"
  752. ],
  753. "text/plain": [
  754. "\n"
  755. ]
  756. },
  757. "metadata": {},
  758. "output_type": "display_data"
  759. },
  760. {
  761. "data": {
  762. "text/html": [
  763. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">300</span> ----------------------------\n",
  764. "</pre>\n"
  765. ],
  766. "text/plain": [
  767. "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n"
  768. ]
  769. },
  770. "metadata": {},
  771. "output_type": "display_data"
  772. },
  773. {
  774. "data": {
  775. "text/html": [
  776. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  777. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"exact#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">63.86732533521524</span>,\n",
  778. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f1#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">78.62546663568186</span>,\n",
  779. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1417</span>,\n",
  780. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_exact#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">63.86732533521524</span>,\n",
  781. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_f1#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">78.62546663568186</span>,\n",
  782. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_total#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1417</span>\n",
  783. "<span style=\"font-weight: bold\">}</span>\n",
  784. "</pre>\n"
  785. ],
  786. "text/plain": [
  787. "\u001b[1m{\u001b[0m\n",
  788. " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m63.86732533521524\u001b[0m,\n",
  789. " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m78.62546663568186\u001b[0m,\n",
  790. " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n",
  791. " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m63.86732533521524\u001b[0m,\n",
  792. " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m78.62546663568186\u001b[0m,\n",
  793. " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n",
  794. "\u001b[1m}\u001b[0m\n"
  795. ]
  796. },
  797. "metadata": {},
  798. "output_type": "display_data"
  799. },
  800. {
  801. "data": {
  802. "text/html": [
  803. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  804. "</pre>\n"
  805. ],
  806. "text/plain": [
  807. "\n"
  808. ]
  809. },
  810. "metadata": {},
  811. "output_type": "display_data"
  812. },
  813. {
  814. "data": {
  815. "text/html": [
  816. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">400</span> ----------------------------\n",
  817. "</pre>\n"
  818. ],
  819. "text/plain": [
  820. "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m400\u001b[0m ----------------------------\n"
  821. ]
  822. },
  823. "metadata": {},
  824. "output_type": "display_data"
  825. },
  826. {
  827. "data": {
  828. "text/html": [
  829. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  830. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"exact#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">64.92589978828511</span>,\n",
  831. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f1#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">79.36746074079691</span>,\n",
  832. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1417</span>,\n",
  833. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_exact#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">64.92589978828511</span>,\n",
  834. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_f1#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">79.36746074079691</span>,\n",
  835. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_total#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1417</span>\n",
  836. "<span style=\"font-weight: bold\">}</span>\n",
  837. "</pre>\n"
  838. ],
  839. "text/plain": [
  840. "\u001b[1m{\u001b[0m\n",
  841. " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m64.92589978828511\u001b[0m,\n",
  842. " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.36746074079691\u001b[0m,\n",
  843. " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n",
  844. " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m64.92589978828511\u001b[0m,\n",
  845. " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.36746074079691\u001b[0m,\n",
  846. " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n",
  847. "\u001b[1m}\u001b[0m\n"
  848. ]
  849. },
  850. "metadata": {},
  851. "output_type": "display_data"
  852. },
  853. {
  854. "data": {
  855. "text/html": [
  856. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  857. "</pre>\n"
  858. ],
  859. "text/plain": [
  860. "\n"
  861. ]
  862. },
  863. "metadata": {},
  864. "output_type": "display_data"
  865. },
  866. {
  867. "data": {
  868. "text/html": [
  869. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">500</span> ----------------------------\n",
  870. "</pre>\n"
  871. ],
  872. "text/plain": [
  873. "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m500\u001b[0m ----------------------------\n"
  874. ]
  875. },
  876. "metadata": {},
  877. "output_type": "display_data"
  878. },
  879. {
  880. "data": {
  881. "text/html": [
  882. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  883. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"exact#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">65.70218772053634</span>,\n",
  884. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f1#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">80.33295482054824</span>,\n",
  885. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1417</span>,\n",
  886. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_exact#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">65.70218772053634</span>,\n",
  887. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_f1#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">80.33295482054824</span>,\n",
  888. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_total#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1417</span>\n",
  889. "<span style=\"font-weight: bold\">}</span>\n",
  890. "</pre>\n"
  891. ],
  892. "text/plain": [
  893. "\u001b[1m{\u001b[0m\n",
  894. " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n",
  895. " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n",
  896. " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n",
  897. " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n",
  898. " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n",
  899. " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n",
  900. "\u001b[1m}\u001b[0m\n"
  901. ]
  902. },
  903. "metadata": {},
  904. "output_type": "display_data"
  905. },
  906. {
  907. "data": {
  908. "text/html": [
  909. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  910. "</pre>\n"
  911. ],
  912. "text/plain": [
  913. "\n"
  914. ]
  915. },
  916. "metadata": {},
  917. "output_type": "display_data"
  918. },
  919. {
  920. "data": {
  921. "text/html": [
  922. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">600</span> ----------------------------\n",
  923. "</pre>\n"
  924. ],
  925. "text/plain": [
  926. "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m600\u001b[0m ----------------------------\n"
  927. ]
  928. },
  929. "metadata": {},
  930. "output_type": "display_data"
  931. },
  932. {
  933. "data": {
  934. "text/html": [
  935. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  936. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"exact#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">65.41990119971771</span>,\n",
  937. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f1#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">79.7483487059053</span>,\n",
  938. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1417</span>,\n",
  939. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_exact#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">65.41990119971771</span>,\n",
  940. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_f1#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">79.7483487059053</span>,\n",
  941. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_total#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1417</span>\n",
  942. "<span style=\"font-weight: bold\">}</span>\n",
  943. "</pre>\n"
  944. ],
  945. "text/plain": [
  946. "\u001b[1m{\u001b[0m\n",
  947. " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.41990119971771\u001b[0m,\n",
  948. " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.7483487059053\u001b[0m,\n",
  949. " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n",
  950. " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.41990119971771\u001b[0m,\n",
  951. " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.7483487059053\u001b[0m,\n",
  952. " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n",
  953. "\u001b[1m}\u001b[0m\n"
  954. ]
  955. },
  956. "metadata": {},
  957. "output_type": "display_data"
  958. },
  959. {
  960. "data": {
  961. "text/html": [
  962. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  963. "</pre>\n"
  964. ],
  965. "text/plain": [
  966. "\n"
  967. ]
  968. },
  969. "metadata": {},
  970. "output_type": "display_data"
  971. },
  972. {
  973. "data": {
  974. "text/html": [
  975. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">700</span> ----------------------------\n",
  976. "</pre>\n"
  977. ],
  978. "text/plain": [
  979. "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m700\u001b[0m ----------------------------\n"
  980. ]
  981. },
  982. "metadata": {},
  983. "output_type": "display_data"
  984. },
  985. {
  986. "data": {
  987. "text/html": [
  988. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  989. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"exact#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">66.61961891319689</span>,\n",
  990. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f1#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">80.32432238994133</span>,\n",
  991. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1417</span>,\n",
  992. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_exact#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">66.61961891319689</span>,\n",
  993. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_f1#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">80.32432238994133</span>,\n",
  994. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_total#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1417</span>\n",
  995. "<span style=\"font-weight: bold\">}</span>\n",
  996. "</pre>\n"
  997. ],
  998. "text/plain": [
  999. "\u001b[1m{\u001b[0m\n",
  1000. " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m66.61961891319689\u001b[0m,\n",
  1001. " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m80.32432238994133\u001b[0m,\n",
  1002. " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n",
  1003. " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m66.61961891319689\u001b[0m,\n",
  1004. " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m80.32432238994133\u001b[0m,\n",
  1005. " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n",
  1006. "\u001b[1m}\u001b[0m\n"
  1007. ]
  1008. },
  1009. "metadata": {},
  1010. "output_type": "display_data"
  1011. },
  1012. {
  1013. "data": {
  1014. "text/html": [
  1015. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1016. "</pre>\n"
  1017. ],
  1018. "text/plain": [
  1019. "\n"
  1020. ]
  1021. },
  1022. "metadata": {},
  1023. "output_type": "display_data"
  1024. },
  1025. {
  1026. "data": {
  1027. "text/html": [
  1028. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">800</span> ----------------------------\n",
  1029. "</pre>\n"
  1030. ],
  1031. "text/plain": [
  1032. "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m800\u001b[0m ----------------------------\n"
  1033. ]
  1034. },
  1035. "metadata": {},
  1036. "output_type": "display_data"
  1037. },
  1038. {
  1039. "data": {
  1040. "text/html": [
  1041. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1042. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"exact#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">65.84333098094567</span>,\n",
  1043. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f1#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">79.23169801265415</span>,\n",
  1044. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1417</span>,\n",
  1045. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_exact#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">65.84333098094567</span>,\n",
  1046. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_f1#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">79.23169801265415</span>,\n",
  1047. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"HasAns_total#squad\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1417</span>\n",
  1048. "<span style=\"font-weight: bold\">}</span>\n",
  1049. "</pre>\n"
  1050. ],
  1051. "text/plain": [
  1052. "\u001b[1m{\u001b[0m\n",
  1053. " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.84333098094567\u001b[0m,\n",
  1054. " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.23169801265415\u001b[0m,\n",
  1055. " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n",
  1056. " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.84333098094567\u001b[0m,\n",
  1057. " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.23169801265415\u001b[0m,\n",
  1058. " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n",
  1059. "\u001b[1m}\u001b[0m\n"
  1060. ]
  1061. },
  1062. "metadata": {},
  1063. "output_type": "display_data"
  1064. },
  1065. {
  1066. "data": {
  1067. "text/html": [
  1068. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  1069. ],
  1070. "text/plain": []
  1071. },
  1072. "metadata": {},
  1073. "output_type": "display_data"
  1074. },
  1075. {
  1076. "data": {
  1077. "text/html": [
  1078. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1079. "</pre>\n"
  1080. ],
  1081. "text/plain": [
  1082. "\n"
  1083. ]
  1084. },
  1085. "metadata": {},
  1086. "output_type": "display_data"
  1087. },
  1088. {
  1089. "data": {
  1090. "text/html": [
  1091. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[19:20:28] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Loading best model from fnlp-ernie-squad/ <a href=\"file://../fastNLP/core/callbacks/load_best_model_callback.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">load_best_model_callback.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/callbacks/load_best_model_callback.py#111\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">111</span></a>\n",
  1092. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2022</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">06</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">27</span>-19_00_15_388554/best_so_far <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  1093. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> with f1#squad: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">80.33295482054824</span><span style=\"color: #808000; text-decoration-color: #808000\">...</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  1094. "</pre>\n"
  1095. ],
  1096. "text/plain": [
  1097. "\u001b[2;36m[19:20:28]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from fnlp-ernie-squad/ \u001b]8;id=163935;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=31503;file://../fastNLP/core/callbacks/load_best_model_callback.py#111\u001b\\\u001b[2m111\u001b[0m\u001b]8;;\u001b\\\n",
  1098. "\u001b[2;36m \u001b[0m \u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m27\u001b[0m-19_00_15_388554/best_so_far \u001b[2m \u001b[0m\n",
  1099. "\u001b[2;36m \u001b[0m with f1#squad: \u001b[1;36m80.33295482054824\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n"
  1100. ]
  1101. },
  1102. "metadata": {},
  1103. "output_type": "display_data"
  1104. },
  1105. {
  1106. "data": {
  1107. "text/html": [
  1108. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Deleting fnlp-ernie-squad/<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2022</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">06</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">27</span>-19_0 <a href=\"file://../fastNLP/core/callbacks/load_best_model_callback.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">load_best_model_callback.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/callbacks/load_best_model_callback.py#131\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">131</span></a>\n",
  1109. "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> 0_15_388554/best_so_far<span style=\"color: #808000; text-decoration-color: #808000\">...</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
  1110. "</pre>\n"
  1111. ],
  1112. "text/plain": [
  1113. "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Deleting fnlp-ernie-squad/\u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m27\u001b[0m-19_0 \u001b]8;id=560859;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=573263;file://../fastNLP/core/callbacks/load_best_model_callback.py#131\u001b\\\u001b[2m131\u001b[0m\u001b]8;;\u001b\\\n",
  1114. "\u001b[2;36m \u001b[0m 0_15_388554/best_so_far\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n"
  1115. ]
  1116. },
  1117. "metadata": {},
  1118. "output_type": "display_data"
  1119. }
  1120. ],
  1121. "source": [
  1122. "from fastNLP import Trainer, LRSchedCallback, LoadBestModelCallback\n",
  1123. "from paddlenlp.transformers import LinearDecayWithWarmup\n",
  1124. "\n",
  1125. "n_epochs = 1\n",
  1126. "num_training_steps = len(train_dataloader) * n_epochs\n",
  1127. "lr_scheduler = LinearDecayWithWarmup(3e-5, num_training_steps, 0.1)\n",
  1128. "optimizer = paddle.optimizer.AdamW(\n",
  1129. " learning_rate=lr_scheduler,\n",
  1130. " parameters=model.parameters(),\n",
  1131. ")\n",
  1132. "callbacks=[\n",
  1133. " LRSchedCallback(lr_scheduler, step_on=\"batch\"),\n",
  1134. " LoadBestModelCallback(\"f1#squad\", larger_better=True, save_folder=\"fnlp-ernie-squad\")\n",
  1135. "]\n",
  1136. "trainer = Trainer(\n",
  1137. " model=model,\n",
  1138. " train_dataloader=train_dataloader,\n",
  1139. " evaluate_dataloaders=val_dataloader,\n",
  1140. " device=1,\n",
  1141. " optimizers=optimizer,\n",
  1142. " n_epochs=n_epochs,\n",
  1143. " callbacks=callbacks,\n",
  1144. " evaluate_every=100,\n",
  1145. " metrics={\"squad\": metric},\n",
  1146. ")\n",
  1147. "trainer.run()"
  1148. ]
  1149. },
  1150. {
  1151. "cell_type": "markdown",
  1152. "metadata": {},
  1153. "source": [
  1154. "#### 3.5 测试\n",
  1155. "\n",
  1156. "最后,我们可以使用 `Evaluator` 查看我们训练的结果。我们在之前为 `SquadEvaluateMetric` 设置了 `testing` 参数来在测试阶段进行输出,可以看到,训练的结果还是比较不错的。"
  1157. ]
  1158. },
  1159. {
  1160. "cell_type": "code",
  1161. "execution_count": 16,
  1162. "metadata": {},
  1163. "outputs": [
  1164. {
  1165. "data": {
  1166. "text/html": [
  1167. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1168. "</pre>\n"
  1169. ],
  1170. "text/plain": [
  1171. "\n"
  1172. ]
  1173. },
  1174. "metadata": {},
  1175. "output_type": "display_data"
  1176. },
  1177. {
  1178. "data": {
  1179. "text/html": [
  1180. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">原文: 爬行垫根据中间材料的不同可以分为:XPE爬行垫、EPE爬行垫、EVA爬行垫、PVC爬行垫;其中XPE爬\n",
  1181. "行垫、EPE爬行垫都属于PE材料加保鲜膜复合而成,都是无异味的环保材料,但是XPE爬行垫是品质较好的爬\n",
  1182. "行垫,韩国进口爬行垫都是这种爬行垫,而EPE爬行垫是国内厂家为了减低成本,使用EPE(珍珠棉)作为原料生\n",
  1183. "产的一款爬行垫,该材料弹性差,易碎,开孔发泡防水性弱。EVA爬行垫、PVC爬行垫是用EVA或PVC作为原材料\n",
  1184. "与保鲜膜复合的而成的爬行垫,或者把图案转印在原材料上,这两款爬行垫通常有异味,如果是图案转印的爬\n",
  1185. "行垫,油墨外露容易脱落。 \n",
  1186. "当时我儿子爬的时候,我们也买了垫子,但是始终有味。最后就没用了,铺的就的薄毯子让他爬。\n",
  1187. "</pre>\n"
  1188. ],
  1189. "text/plain": [
  1190. "原文: 爬行垫根据中间材料的不同可以分为:XPE爬行垫、EPE爬行垫、EVA爬行垫、PVC爬行垫;其中XPE爬\n",
  1191. "行垫、EPE爬行垫都属于PE材料加保鲜膜复合而成,都是无异味的环保材料,但是XPE爬行垫是品质较好的爬\n",
  1192. "行垫,韩国进口爬行垫都是这种爬行垫,而EPE爬行垫是国内厂家为了减低成本,使用EPE(珍珠棉)作为原料生\n",
  1193. "产的一款爬行垫,该材料弹性差,易碎,开孔发泡防水性弱。EVA爬行垫、PVC爬行垫是用EVA或PVC作为原材料\n",
  1194. "与保鲜膜复合的而成的爬行垫,或者把图案转印在原材料上,这两款爬行垫通常有异味,如果是图案转印的爬\n",
  1195. "行垫,油墨外露容易脱落。 \n",
  1196. "当时我儿子爬的时候,我们也买了垫子,但是始终有味。最后就没用了,铺的就的薄毯子让他爬。\n"
  1197. ]
  1198. },
  1199. "metadata": {},
  1200. "output_type": "display_data"
  1201. },
  1202. {
  1203. "data": {
  1204. "text/html": [
  1205. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">问题: 爬行垫什么材质的好 答案: EPE(珍珠棉 正确答案: ['XPE']\n",
  1206. "</pre>\n"
  1207. ],
  1208. "text/plain": [
  1209. "问题: 爬行垫什么材质的好 答案: EPE(珍珠棉 正确答案: ['XPE']\n"
  1210. ]
  1211. },
  1212. "metadata": {},
  1213. "output_type": "display_data"
  1214. },
  1215. {
  1216. "data": {
  1217. "text/html": [
  1218. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1219. "</pre>\n"
  1220. ],
  1221. "text/plain": [
  1222. "\n"
  1223. ]
  1224. },
  1225. "metadata": {},
  1226. "output_type": "display_data"
  1227. },
  1228. {
  1229. "data": {
  1230. "text/html": [
  1231. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">原文: 真实情况是160-162。她平时谎报的168是因为不离脚穿高水台恨天高(15厘米) 图1她穿着高水台恨\n",
  1232. "天高和刘亦菲一样高,(刘亦菲对外报身高172)范冰冰礼服下厚厚的高水台暴露了她的心机,对比一下两者的\n",
  1233. "鞋子吧 图2 穿着高水台恨天高才和刘德华谢霆锋持平,如果她真的有168,那么加上鞋高,刘和谢都要有180?\n",
  1234. "明显是不可能的。所以刘德华对外报的身高174减去10-15厘米才是范冰冰的真实身高 图3,范冰冰有一次脱\n",
  1235. "鞋上场,这个最说明问题了,看看她的身体比例吧。还有目测一下她手上鞋子的鞋跟有多高多厚吧,至少超过\n",
  1236. "10厘米。\n",
  1237. "</pre>\n"
  1238. ],
  1239. "text/plain": [
  1240. "原文: 真实情况是160-162。她平时谎报的168是因为不离脚穿高水台恨天高(15厘米) 图1她穿着高水台恨\n",
  1241. "天高和刘亦菲一样高,(刘亦菲对外报身高172)范冰冰礼服下厚厚的高水台暴露了她的心机,对比一下两者的\n",
  1242. "鞋子吧 图2 穿着高水台恨天高才和刘德华谢霆锋持平,如果她真的有168,那么加上鞋高,刘和谢都要有180?\n",
  1243. "明显是不可能的。所以刘德华对外报的身高174减去10-15厘米才是范冰冰的真实身高 图3,范冰冰有一次脱\n",
  1244. "鞋上场,这个最说明问题了,看看她的身体比例吧。还有目测一下她手上鞋子的鞋跟有多高多厚吧,至少超过\n",
  1245. "10厘米。\n"
  1246. ]
  1247. },
  1248. "metadata": {},
  1249. "output_type": "display_data"
  1250. },
  1251. {
  1252. "data": {
  1253. "text/html": [
  1254. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">问题: 范冰冰多高真实身高 答案: 160-162 正确答案: ['160-162']\n",
  1255. "</pre>\n"
  1256. ],
  1257. "text/plain": [
  1258. "问题: 范冰冰多高真实身高 答案: 160-162 正确答案: ['160-162']\n"
  1259. ]
  1260. },
  1261. "metadata": {},
  1262. "output_type": "display_data"
  1263. },
  1264. {
  1265. "data": {
  1266. "text/html": [
  1267. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1268. "</pre>\n"
  1269. ],
  1270. "text/plain": [
  1271. "\n"
  1272. ]
  1273. },
  1274. "metadata": {},
  1275. "output_type": "display_data"
  1276. },
  1277. {
  1278. "data": {
  1279. "text/html": [
  1280. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">原文: 防水作为目前高端手机的标配,特别是苹果也支持防水之后,国产大多数高端旗舰手机都已经支持防\n",
  1281. "水。虽然我们真的不会故意把手机放入水中,但是有了防水之后,用户心里会多一重安全感。那么近日最为\n",
  1282. "火热的小米6防水吗?小米6的防水级别又是多少呢? 小编查询了很多资料发现,小米6确实是防水的,但是为\n",
  1283. "了保持低调,同时为了不被别人说防水等级不够,很多资料都没有标注小米是否防水。根据评测资料显示,小\n",
  1284. "米6是支持IP68级的防水,是绝对能够满足日常生活中的防水需求的。\n",
  1285. "</pre>\n"
  1286. ],
  1287. "text/plain": [
  1288. "原文: 防水作为目前高端手机的标配,特别是苹果也支持防水之后,国产大多数高端旗舰手机都已经支持防\n",
  1289. "水。虽然我们真的不会故意把手机放入水中,但是有了防水之后,用户心里会多一重安全感。那么近日最为\n",
  1290. "火热的小米6防水吗?小米6的防水级别又是多少呢? 小编查询了很多资料发现,小米6确实是防水的,但是为\n",
  1291. "了保持低调,同时为了不被别人说防水等级不够,很多资料都没有标注小米是否防水。根据评测资料显示,小\n",
  1292. "米6是支持IP68级的防水,是绝对能够满足日常生活中的防水需求的。\n"
  1293. ]
  1294. },
  1295. "metadata": {},
  1296. "output_type": "display_data"
  1297. },
  1298. {
  1299. "data": {
  1300. "text/html": [
  1301. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">问题: 小米6防水等级 答案: IP68级 正确答案: ['IP68级']\n",
  1302. "</pre>\n"
  1303. ],
  1304. "text/plain": [
  1305. "问题: 小米6防水等级 答案: IP68级 正确答案: ['IP68级']\n"
  1306. ]
  1307. },
  1308. "metadata": {},
  1309. "output_type": "display_data"
  1310. },
  1311. {
  1312. "data": {
  1313. "text/html": [
  1314. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1315. "</pre>\n"
  1316. ],
  1317. "text/plain": [
  1318. "\n"
  1319. ]
  1320. },
  1321. "metadata": {},
  1322. "output_type": "display_data"
  1323. },
  1324. {
  1325. "data": {
  1326. "text/html": [
  1327. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">原文: 这位朋友你好,女性出现妊娠反应一般是从6-12周左右,也就是女性怀孕1个多月就会开始出现反应,\n",
  1328. "第3个月的时候,妊辰反应基本结束。 而大部分女性怀孕初期都会出现恶心、呕吐的感觉,这些症状都是因\n",
  1329. "人而异的,除非恶心、呕吐的非常厉害,才需要就医,否则这些都是刚怀孕的的正常症状。1-3个月的时候可\n",
  1330. "以观察一下自己的皮肤,一般女性怀孕初期可能会产生皮肤色素沉淀或是腹壁产生妊娠纹,特别是在怀孕的\n",
  1331. "后期更加明显。 还有很多女性怀孕初期会出现疲倦、嗜睡的情况。怀孕三个月的时候,膀胱会受到日益胀\n",
  1332. "大的子宫的压迫,容量会变小,所以怀孕期间也会有尿频的现象出现。月经停止也是刚怀孕最容易出现的症\n",
  1333. "状,只要是平时月经正常的女性,在性行为后超过正常经期两周,就有可能是怀孕了。 如果你想判断自己是\n",
  1334. "否怀孕,可以看看自己有没有这些反应。当然这也只是多数人的怀孕表现,也有部分女性怀孕表现并不完全\n",
  1335. "是这样,如果你无法确定自己是否怀孕,最好去医院检查一下。\n",
  1336. "</pre>\n"
  1337. ],
  1338. "text/plain": [
  1339. "原文: 这位朋友你好,女性出现妊娠反应一般是从6-12周左右,也就是女性怀孕1个多月就会开始出现反应,\n",
  1340. "第3个月的时候,妊辰反应基本结束。 而大部分女性怀孕初期都会出现恶心、呕吐的感觉,这些症状都是因\n",
  1341. "人而异的,除非恶心、呕吐的非常厉害,才需要就医,否则这些都是刚怀孕的的正常症状。1-3个月的时候可\n",
  1342. "以观察一下自己的皮肤,一般女性怀孕初期可能会产生皮肤色素沉淀或是腹壁产生妊娠纹,特别是在怀孕的\n",
  1343. "后期更加明显。 还有很多女性怀孕初期会出现疲倦、嗜睡的情况。怀孕三个月的时候,膀胱会受到日益胀\n",
  1344. "大的子宫的压迫,容量会变小,所以怀孕期间也会有尿频的现象出现。月经停止也是刚怀孕最容易出现的症\n",
  1345. "状,只要是平时月经正常的女性,在性行为后超过正常经期两周,就有可能是怀孕了。 如果你想判断自己是\n",
  1346. "否怀孕,可以看看自己有没有这些反应。当然这也只是多数人的怀孕表现,也有部分女性怀孕表现并不完全\n",
  1347. "是这样,如果你无法确定自己是否怀孕,最好去医院检查一下。\n"
  1348. ]
  1349. },
  1350. "metadata": {},
  1351. "output_type": "display_data"
  1352. },
  1353. {
  1354. "data": {
  1355. "text/html": [
  1356. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">问题: 怀孕多久会有反应 答案: 6-12周左右 正确答案: ['6-12周左右', '6-12周', '1个多月']\n",
  1357. "</pre>\n"
  1358. ],
  1359. "text/plain": [
  1360. "问题: 怀孕多久会有反应 答案: 6-12周左右 正确答案: ['6-12周左右', '6-12周', '1个多月']\n"
  1361. ]
  1362. },
  1363. "metadata": {},
  1364. "output_type": "display_data"
  1365. },
  1366. {
  1367. "data": {
  1368. "text/html": [
  1369. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1370. "</pre>\n"
  1371. ],
  1372. "text/plain": [
  1373. "\n"
  1374. ]
  1375. },
  1376. "metadata": {},
  1377. "output_type": "display_data"
  1378. },
  1379. {
  1380. "data": {
  1381. "text/html": [
  1382. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">原文: 【东奥会计在线——中级会计职称频道推荐】根据《关于提高科技型中小企业研究开发费用税前加计\n",
  1383. "扣除比例的通知》的规定,研发费加计扣除比例提高到75%。|财政部、国家税务总局、科技部发布《关于提\n",
  1384. "高科技型中小企业研究开发费用税前加计扣除比例的通知》。|通知称,为进一步激励中小企业加大研发投\n",
  1385. "入,支持科技创新,就提高科技型中小企业研究开发费用(以下简称研发费用)税前加计扣除比例有关问题发\n",
  1386. "布通知。|通知明确,科技型中小企业开展研发活动中实际发生的研发费用,未形成无形资产计入当期损益的\n",
  1387. ",在按规定据实扣除的基础上,在2017年1月1日至2019年12月31日期间,再按照实际发生额的75%在税前加计\n",
  1388. "扣除;形成无形资产的,在上述期间按照无形资产成本的175%在税前摊销。|科技型中小企业享受研发费用税\n",
  1389. "前加计扣除政策的其他政策口径按照《财政部国家税务总局科技部关于完善研究开发费用税前加计扣除政\n",
  1390. "策的通知》(财税〔2015〕119号)规定执行。|科技型中小企业条件和管理办法由科技部、财政部和国家税\n",
  1391. "务总局另行发布。科技、财政和税务部门应建立信息共享机制,及时共享科技型中小企业的相关信息,加强\n",
  1392. "协调配合,保障优惠政策落实到位。|上一篇文章:关于2016年度企业研究开发费用税前加计扣除政策企业所\n",
  1393. "得税纳税申报问题的公告 下一篇文章:关于提高科技型中小企业研究开发费用税前加计扣除比例的通知\n",
  1394. "</pre>\n"
  1395. ],
  1396. "text/plain": [
  1397. "原文: 【东奥会计在线——中级会计职称频道推荐】根据《关于提高科技型中小企业研究开发费用税前加计\n",
  1398. "扣除比例的通知》的规定,研发费加计扣除比例提高到75%。|财政部、国家税务总局、科技部发布《关于提\n",
  1399. "高科技型中小企业研究开发费用税前加计扣除比例的通知》。|通知称,为进一步激励中小企业加大研发投\n",
  1400. "入,支持科技创新,就提高科技型中小企业研究开发费用(以下简称研发费用)税前加计扣除比例有关问题发\n",
  1401. "布通知。|通知明确,科技型中小企业开展研发活动中实际发生的研发费用,未形成无形资产计入当期损益的\n",
  1402. ",在按规定据实扣除的基础上,在2017年1月1日至2019年12月31日期间,再按照实际发生额的75%在税前加计\n",
  1403. "扣除;形成无形资产的,在上述期间按照无形资产成本的175%在税前摊销。|科技型中小企业享受研发费用税\n",
  1404. "前加计扣除政策的其他政策口径按照《财政部国家税务总局科技部关于完善研究开发费用税前加计扣除政\n",
  1405. "策的通知》(财税〔2015〕119号)规定执行。|科技型中小企业条件和管理办法由科技部、财政部和国家税\n",
  1406. "务总局另行发布。科技、财政和税务部门应建立信息共享机制,及时共享科技型中小企业的相关信息,加强\n",
  1407. "协调配合,保障优惠政策落实到位。|上一篇文章:关于2016年度企业研究开发费用税前加计扣除政策企业所\n",
  1408. "得税纳税申报问题的公告 下一篇文章:关于提高科技型中小企业研究开发费用税前加计扣除比例的通知\n"
  1409. ]
  1410. },
  1411. "metadata": {},
  1412. "output_type": "display_data"
  1413. },
  1414. {
  1415. "data": {
  1416. "text/html": [
  1417. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">问题: 研发费用加计扣除比例 答案: 75% 正确答案: ['75%']\n",
  1418. "</pre>\n"
  1419. ],
  1420. "text/plain": [
  1421. "问题: 研发费用加计扣除比例 答案: 75% 正确答案: ['75%']\n"
  1422. ]
  1423. },
  1424. "metadata": {},
  1425. "output_type": "display_data"
  1426. },
  1427. {
  1428. "data": {
  1429. "text/html": [
  1430. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  1431. ],
  1432. "text/plain": []
  1433. },
  1434. "metadata": {},
  1435. "output_type": "display_data"
  1436. },
  1437. {
  1438. "data": {
  1439. "text/html": [
  1440. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1441. " <span style=\"color: #008000; text-decoration-color: #008000\">'exact#squad'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">65.70218772053634</span>,\n",
  1442. " <span style=\"color: #008000; text-decoration-color: #008000\">'f1#squad'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">80.33295482054824</span>,\n",
  1443. " <span style=\"color: #008000; text-decoration-color: #008000\">'total#squad'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1417</span>,\n",
  1444. " <span style=\"color: #008000; text-decoration-color: #008000\">'HasAns_exact#squad'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">65.70218772053634</span>,\n",
  1445. " <span style=\"color: #008000; text-decoration-color: #008000\">'HasAns_f1#squad'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">80.33295482054824</span>,\n",
  1446. " <span style=\"color: #008000; text-decoration-color: #008000\">'HasAns_total#squad'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1417</span>\n",
  1447. "<span style=\"font-weight: bold\">}</span>\n",
  1448. "</pre>\n"
  1449. ],
  1450. "text/plain": [
  1451. "\u001b[1m{\u001b[0m\n",
  1452. " \u001b[32m'exact#squad'\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n",
  1453. " \u001b[32m'f1#squad'\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n",
  1454. " \u001b[32m'total#squad'\u001b[0m: \u001b[1;36m1417\u001b[0m,\n",
  1455. " \u001b[32m'HasAns_exact#squad'\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n",
  1456. " \u001b[32m'HasAns_f1#squad'\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n",
  1457. " \u001b[32m'HasAns_total#squad'\u001b[0m: \u001b[1;36m1417\u001b[0m\n",
  1458. "\u001b[1m}\u001b[0m\n"
  1459. ]
  1460. },
  1461. "metadata": {},
  1462. "output_type": "display_data"
  1463. }
  1464. ],
  1465. "source": [
  1466. "from fastNLP import Evaluator\n",
  1467. "evaluator = Evaluator(\n",
  1468. " model=model,\n",
  1469. " dataloaders=val_dataloader,\n",
  1470. " device=1,\n",
  1471. " metrics={\n",
  1472. " \"squad\": SquadEvaluateMetric(\n",
  1473. " val_dataloader.dataset.data,\n",
  1474. " val_dataloader.dataset.new_data,\n",
  1475. " testing=True,\n",
  1476. " ),\n",
  1477. " },\n",
  1478. ")\n",
  1479. "result = evaluator.run()"
  1480. ]
  1481. }
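,
{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "After training, `LoadBestModelCallback` has already restored the weights with the best `f1#squad` into `model` (see the log above). If you also want to keep these weights on disk, here is a minimal sketch using standard Paddle APIs; the file path is only illustrative and not part of this tutorial:\n",
  "\n",
  "```python\n",
  "import paddle\n",
  "\n",
  "# Save the fine-tuned parameters (illustrative path)\n",
  "paddle.save(model.state_dict(), \"fnlp-ernie-squad/ernie_squad_final.pdparams\")\n",
  "\n",
  "# Later, restore them into a model of the same architecture:\n",
  "# model.set_state_dict(paddle.load(\"fnlp-ernie-squad/ernie_squad_final.pdparams\"))\n",
  "```"
 ]
}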
  1482. ],
  1483. "metadata": {
  1484. "kernelspec": {
  1485. "display_name": "Python 3.7.13 ('fnlp-paddle')",
  1486. "language": "python",
  1487. "name": "python3"
  1488. },
  1489. "language_info": {
  1490. "codemirror_mode": {
  1491. "name": "ipython",
  1492. "version": 3
  1493. },
  1494. "file_extension": ".py",
  1495. "mimetype": "text/x-python",
  1496. "name": "python",
  1497. "nbconvert_exporter": "python",
  1498. "pygments_lexer": "ipython3",
  1499. "version": "3.7.13"
  1500. },
  1501. "orig_nbformat": 4,
  1502. "vscode": {
  1503. "interpreter": {
  1504. "hash": "31f2d9d3efc23c441973d7c4273acfea8b132b6a578f002629b6b44b8f65e720"
  1505. }
  1506. }
  1507. },
  1508. "nbformat": 4,
  1509. "nbformat_minor": 2
  1510. }