
fastnlp_tutorial_5.ipynb

  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "id": "fdd7ff16",
  6. "metadata": {},
  7. "source": [
  8. "# T5. fastNLP 中的预定义模型\n",
  9. "\n",
  10. "  1   fastNLP 中 modules 的介绍\n",
  11. " \n",
  12. "    1.1   modules 模块、models 模块 简介\n",
  13. "\n",
  14. "    1.2   示例一:modules 实现 LSTM 分类\n",
  15. "\n",
  16. "  2   fastNLP 中 models 的介绍\n",
  17. " \n",
  18. "    2.1   示例一:models 实现 CNN 分类\n",
  19. "\n",
  20. "    2.3   示例二:models 实现 BiLSTM 标注"
  21. ]
  22. },
  23. {
  24. "cell_type": "markdown",
  25. "id": "d3d65d53",
  26. "metadata": {},
  27. "source": [
  28. "## 1. fastNLP 中 modules 模块的介绍\n",
  29. "\n",
  30. "### 1.1 modules 模块、models 模块 简介\n",
  31. "\n",
  32. "在`fastNLP 0.8`中,**`modules.torch`路径下定义了一些基于`pytorch`实现的基础模块**\n",
  33. "\n",
  34. "    包括长短期记忆网络`LSTM`、条件随机场`CRF`、`transformer`的编解码器模块等,详见下表\n",
  35. "\n",
  36. "| <div align=\"center\">代码名称</div> | <div align=\"center\">简要介绍</div> | <div align=\"center\">代码路径</div> |\n",
  37. "|:--|:--|:--|\n",
  38. "| `LSTM` | 轻量封装`pytorch`的`LSTM` | `/modules/torch/encoder/lstm.py` |\n",
  39. "| `Seq2SeqEncoder` | 序列变换编码器,基类 | `/modules/torch/encoder/seq2seq_encoder.py` |\n",
  40. "| `LSTMSeq2SeqEncoder` | 序列变换编码器,基于`LSTM` | `/modules/torch/encoder/seq2seq_encoder.py` |\n",
  41. "| `TransformerSeq2SeqEncoder` | 序列变换编码器,基于`transformer` | `/modules/torch/encoder/seq2seq_encoder.py` |\n",
  42. "| `StarTransformer` | `Star-Transformer`的编码器部分 | `/modules/torch/encoder/star_transformer.py` |\n",
  43. "| `VarRNN` | 实现`Variational Dropout RNN` | `/modules/torch/encoder/variational_rnn.py` |\n",
  44. "| `VarLSTM` | 实现`Variational Dropout LSTM` | `/modules/torch/encoder/variational_rnn.py` |\n",
  45. "| `VarGRU` | 实现`Variational Dropout GRU` | `/modules/torch/encoder/variational_rnn.py` |\n",
  46. "| `ConditionalRandomField` | 条件随机场模型 | `/modules/torch/decoder/crf.py` |\n",
  47. "| `Seq2SeqDecoder` | 序列变换解码器,基类 | `/modules/torch/decoder/seq2seq_decoder.py` |\n",
  48. "| `LSTMSeq2SeqDecoder` | 序列变换解码器,基于`LSTM` | `/modules/torch/decoder/seq2seq_decoder.py` |\n",
  49. "| `TransformerSeq2SeqDecoder` | 序列变换解码器,基于`transformer` | `/modules/torch/decoder/seq2seq_decoder.py` |\n",
  50. "| `SequenceGenerator` | 序列生成,封装`Seq2SeqDecoder` | `/models/torch/sequence_labeling.py` |\n",
  51. "| `TimestepDropout` | 在每个`timestamp`上`dropout` | `/modules/torch/dropout.py` |"
  52. ]
  53. },
  54. {
  55. "cell_type": "markdown",
  56. "id": "89ffcf07",
  57. "metadata": {},
  58. "source": [
  59. "&emsp; **`models.torch`路径下定义了一些基于`pytorch`、`modules`实现的预定义模型** \n",
  60. "\n",
  61. "&emsp; &emsp; 例如基于`CNN`的分类模型、基于`BiLSTM+CRF`的标注模型、基于[双仿射注意力机制](https://arxiv.org/pdf/1611.01734.pdf)的分析模型\n",
  62. "\n",
  63. "&emsp; &emsp; 基于`modules.torch`中的`LSTM`/`transformer`编/解码器模块的序列变换/生成模型,详见下表\n",
  64. "\n",
  65. "| <div align=\"center\">代码名称</div> | <div align=\"center\">简要介绍</div> | <div align=\"center\">代码路径</div> |\n",
  66. "|:--|:--|:--|\n",
  67. "| `BiaffineParser` | 句法分析模型,基于双仿射注意力 | `/models/torch/biaffine_parser.py` |\n",
  68. "| `CNNText` | 文本分类模型,基于`CNN` | `/models/torch/cnn_text_classification.py` |\n",
  69. "| `Seq2SeqModel` | 序列变换,基类`encoder+decoder` | `/models/torch/seq2seq_model.py` |\n",
  70. "| `LSTMSeq2SeqModel` | 序列变换,基于`LSTM` | `/models/torch/seq2seq_model.py` |\n",
  71. "| `TransformerSeq2SeqModel` | 序列变换,基于`transformer` | `/models/torch/seq2seq_model.py` |\n",
  72. "| `SequenceGeneratorModel` | 封装`Seq2SeqModel`,结合`SequenceGenerator` | `/models/torch/seq2seq_generator.py` |\n",
  73. "| `SeqLabeling` | 标注模型,基类`LSTM+FC+CRF` | `/models/torch/sequence_labeling.py` |\n",
  74. "| `BiLSTMCRF` | 标注模型,`BiLSTM+FC+CRF` | `/models/torch/sequence_labeling.py` |\n",
  75. "| `AdvSeqLabel` | 标注模型,`LN+BiLSTM*2+LN+FC+CRF` | `/models/torch/sequence_labeling.py` |"
  76. ]
  77. },
  78. {
  79. "cell_type": "markdown",
  80. "id": "61318354",
  81. "metadata": {},
  82. "source": [
  83. "上述`fastNLP`模块,不仅**为入门级用户提供了简单易用的工具**,以解决各种`NLP`任务,或复现相关论文\n",
  84. "\n",
  85. "&emsp; 同时**也为专业研究人员提供了便捷可操作的接口**,封装部分代码的同时,也能指定参数修改细节\n",
  86. "\n",
  87. "&emsp; 在接下来的`tutorial`中,我们将通过`SST-2`分类和`CoNLL-2003`标注,展示相关模型使用\n",
  88. "\n",
  89. "注一:**`SST`**,**单句情感分类**数据集,包含电影评论和对应情感极性,1 对应正面情感,0 对应负面情感\n",
  90. "\n",
  91. "&emsp; 数据集包括三部分:训练集 67350 条,验证集 873 条,测试集 1821 条,更多参考[下载链接](https://gluebenchmark.com/tasks)\n",
  92. "\n",
  93. "注二:**`CoNLL-2003`**,**文本语法标注**数据集,包含语句和对应的词性标签`pos_tags`(名动形数量代)\n",
  94. "\n",
  95. "&emsp; 语法结构标签`chunk_tags`(主谓宾定状补)、命名实体标签`ner_tags`(人名、组织名、地名、时间等)\n",
  96. "\n",
  97. "&emsp; 数据集包括三部分:训练集 14041 条,验证集 3250 条,测试集 3453 条,更多参考[原始论文](https://aclanthology.org/W03-0419.pdf)"
  98. ]
  99. },
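{
"cell_type": "markdown",
"id": "added-modules-sketch-md",
"metadata": {},
"source": [
"&emsp; The building blocks listed in the two tables of section 1.1 are ordinary `pytorch` modules and can be tried out on their own\n",
"\n",
"&emsp; &emsp; the next cell is only a minimal sketch (not taken from the official examples), assuming the `LSTM` wrapper keeps `pytorch`'s usual `(output, (hidden, cell))` return convention, the same unpacking used in section 1.2 below"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-modules-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"from fastNLP.modules.torch import LSTM  # listed in the modules.torch table above\n",
"\n",
"# a bidirectional LSTM over 100-dim inputs with a 64-dim hidden state per direction\n",
"lstm = LSTM(100, 64, num_layers=1, bidirectional=True)\n",
"\n",
"x = torch.randn(2, 7, 100)        # (batch_size, seq_len, input_size)\n",
"output, (hidden, cell) = lstm(x)  # same unpacking as in section 1.2\n",
"print(output.shape)               # expected: torch.Size([2, 7, 128]), i.e. 2 * 64 features per timestep"
]
},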
  100. {
  101. "cell_type": "markdown",
  102. "id": "2a36bbe4",
  103. "metadata": {},
  104. "source": [
  105. "### 1.2 示例一:modules 实现 LSTM 分类"
  106. ]
  107. },
  108. {
  109. "cell_type": "code",
  110. "execution_count": 1,
  111. "id": "40e66b21",
  112. "metadata": {},
  113. "outputs": [],
  114. "source": [
  115. "# import sys\n",
  116. "# sys.path.append('..')\n",
  117. "\n",
  118. "# from fastNLP.io import SST2Pipe # 没有 SST2Pipe 会运行很长时间,并且还会报错\n",
  119. "\n",
  120. "# databundle = SST2Pipe(tokenizer='raw').process_from_file()\n",
  121. "\n",
  122. "# dataset = databundle.get_dataset('train')[:6000]\n",
  123. "\n",
  124. "# dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n",
  125. "# progress_bar=\"tqdm\")\n",
  126. "# dataset.delete_field('sentence')\n",
  127. "# dataset.delete_field('label')\n",
  128. "# dataset.delete_field('idx')\n",
  129. "\n",
  130. "# from fastNLP import Vocabulary\n",
  131. "\n",
  132. "# vocab = Vocabulary()\n",
  133. "# vocab.from_dataset(dataset, field_name='words')\n",
  134. "# vocab.index_dataset(dataset, field_name='words')\n",
  135. "\n",
  136. "# train_dataset, evaluate_dataset = dataset.split(ratio=0.85)"
  137. ]
  138. },
  139. {
  140. "cell_type": "code",
  141. "execution_count": 2,
  142. "id": "50960476",
  143. "metadata": {},
  144. "outputs": [],
  145. "source": [
  146. "# from fastNLP import prepare_torch_dataloader\n",
  147. "\n",
  148. "# train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
  149. "# evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
  150. ]
  151. },
  152. {
  153. "cell_type": "code",
  154. "execution_count": 3,
  155. "id": "0b25b25c",
  156. "metadata": {},
  157. "outputs": [],
  158. "source": [
  159. "# import torch\n",
  160. "# import torch.nn as nn\n",
  161. "\n",
  162. "# from fastNLP.modules.torch import LSTM, MLP # 没有 MLP\n",
  163. "# from fastNLP import Embedding, CrossEntropyLoss\n",
  164. "\n",
  165. "\n",
  166. "# class ClsByModules(nn.Module):\n",
  167. "# def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n",
  168. "# nn.Module.__init__(self)\n",
  169. "\n",
  170. "# self.embedding = Embedding((vocab_size, embedding_dim))\n",
  171. "# self.lstm = LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True)\n",
  172. "# self.mlp = MLP([hidden_dim * 2, output_dim], dropout=dropout)\n",
  173. " \n",
  174. "# self.loss_fn = CrossEntropyLoss()\n",
  175. "\n",
  176. "# def forward(self, words):\n",
  177. "# output = self.embedding(words)\n",
  178. "# output, (hidden, cell) = self.lstm(output)\n",
  179. "# output = self.mlp(torch.cat((hidden[-1], hidden[-2]), dim=1))\n",
  180. "# return output\n",
  181. " \n",
  182. "# def train_step(self, words, target):\n",
  183. "# pred = self(words)\n",
  184. "# return {\"loss\": self.loss_fn(pred, target)}\n",
  185. "\n",
  186. "# def evaluate_step(self, words, target):\n",
  187. "# pred = self(words)\n",
  188. "# pred = torch.max(pred, dim=-1)[1]\n",
  189. "# return {\"pred\": pred, \"target\": target}"
  190. ]
  191. },
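{
"cell_type": "markdown",
"id": "added-clsbymodules-sketch-md",
"metadata": {},
"source": [
"&emsp; Since `MLP` (and the data pipeline above) is unavailable, the cell above is left commented out; the next cell is a hedged, self-contained sketch rather than an official fastNLP example\n",
"\n",
"&emsp; &emsp; it keeps the `LSTM` from `modules.torch` but swaps the unavailable `MLP` and `Embedding` wrappers for plain `torch.nn` layers, and only checks shapes on random token ids (`vocab_size=100`, `embedding_dim=50` are made-up values)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-clsbymodules-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"from fastNLP.modules.torch import LSTM\n",
"\n",
"\n",
"# same structure as ClsByModules above, with nn.Embedding / nn.Linear replacing the unavailable Embedding / MLP\n",
"class ClsByModulesSketch(nn.Module):\n",
"    def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n",
"        super().__init__()\n",
"        self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
"        self.lstm = LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True)\n",
"        self.dropout = nn.Dropout(dropout)\n",
"        self.fc = nn.Linear(hidden_dim * 2, output_dim)\n",
"        self.loss_fn = nn.CrossEntropyLoss()\n",
"\n",
"    def forward(self, words):\n",
"        output = self.embedding(words)\n",
"        output, (hidden, cell) = self.lstm(output)\n",
"        # concatenate the final forward and backward hidden states, as in the cell above\n",
"        return self.fc(self.dropout(torch.cat((hidden[-1], hidden[-2]), dim=1)))\n",
"\n",
"    def train_step(self, words, target):\n",
"        return {'loss': self.loss_fn(self(words), target)}\n",
"\n",
"    def evaluate_step(self, words, target):\n",
"        return {'pred': torch.max(self(words), dim=-1)[1], 'target': target}\n",
"\n",
"\n",
"# shape check on made-up data: a batch of 4 sentences with 12 token ids each, vocabulary of 100 words\n",
"model = ClsByModulesSketch(vocab_size=100, embedding_dim=50, output_dim=2)\n",
"print(model.train_step(torch.randint(0, 100, (4, 12)), torch.tensor([0, 1, 1, 0])))"
]
},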
  192. {
  193. "cell_type": "code",
  194. "execution_count": 4,
  195. "id": "9dbbf50d",
  196. "metadata": {},
  197. "outputs": [],
  198. "source": [
  199. "# model = ClsByModules(vocab_size=len(vocabulary), embedding_dim=100, output_dim=2)\n",
  200. "\n",
  201. "# from torch.optim import AdamW\n",
  202. "\n",
  203. "# optimizers = AdamW(params=model.parameters(), lr=5e-5)"
  204. ]
  205. },
  206. {
  207. "cell_type": "code",
  208. "execution_count": 5,
  209. "id": "7a93432f",
  210. "metadata": {},
  211. "outputs": [],
  212. "source": [
  213. "# from fastNLP import Trainer, Accuracy\n",
  214. "\n",
  215. "# trainer = Trainer(\n",
  216. "# model=model,\n",
  217. "# driver='torch',\n",
  218. "# device=0, # 'cuda'\n",
  219. "# n_epochs=10,\n",
  220. "# optimizers=optimizers,\n",
  221. "# train_dataloader=train_dataloader,\n",
  222. "# evaluate_dataloaders=evaluate_dataloader,\n",
  223. "# metrics={'acc': Accuracy()}\n",
  224. "# )"
  225. ]
  226. },
  227. {
  228. "cell_type": "code",
  229. "execution_count": 6,
  230. "id": "31102e0f",
  231. "metadata": {},
  232. "outputs": [],
  233. "source": [
  234. "# trainer.run(num_eval_batch_per_dl=10)"
  235. ]
  236. },
  237. {
  238. "cell_type": "code",
  239. "execution_count": 7,
  240. "id": "8bc4bfb2",
  241. "metadata": {},
  242. "outputs": [],
  243. "source": [
  244. "# trainer.evaluator.run()"
  245. ]
  246. },
  247. {
  248. "cell_type": "markdown",
  249. "id": "d9443213",
  250. "metadata": {},
  251. "source": [
  252. "## 2. fastNLP 中 models 模块的介绍\n",
  253. "\n",
  254. "### 2.1 示例一:models 实现 CNN 分类\n",
  255. "\n",
  256. "&emsp; 本示例使用`fastNLP 0.8`中预定义模型`models`中的`CNNText`模型,实现`SST-2`文本二分类任务\n",
  257. "\n",
  258. "模型使用方面,如上所述,这里使用**基于卷积神经网络`CNN`的预定义文本分类模型`CNNText`**,结构如下所示\n",
  259. "\n",
  260. "&emsp; 首先是内置的`100`维嵌入层、`dropout`层、紧接着是三个一维卷积,将`100`维嵌入特征,分别通过\n",
  261. "\n",
  262. "&emsp; &emsp; **感受野为`1`、`3`、`5`的卷积算子变换至`30`维、`40`维、`50`维的卷积特征**,再将三者拼接\n",
  263. "\n",
  264. "&emsp; 最终再次通过`dropout`层、线性变换层,映射至二元的输出值,对应两个分类结果上的几率`logits`\n",
  265. "\n",
  266. "```\n",
  267. "CNNText(\n",
  268. " (embed): Embedding(\n",
  269. " (embed): Embedding(5194, 100)\n",
  270. " (dropout): Dropout(p=0.0, inplace=False)\n",
  271. " )\n",
  272. " (conv_pool): ConvMaxpool(\n",
  273. " (convs): ModuleList(\n",
  274. " (0): Conv1d(100, 30, kernel_size=(1,), stride=(1,), bias=False)\n",
  275. " (1): Conv1d(100, 40, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
  276. " (2): Conv1d(100, 50, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
  277. " )\n",
  278. " )\n",
  279. " (dropout): Dropout(p=0.1, inplace=False)\n",
  280. " (fc): Linear(in_features=120, out_features=2, bias=True)\n",
  281. ")\n",
  282. "```\n",
  283. "\n",
  284. "数据使用方面,此处**使用`datasets`模块中的`load_dataset`函数**,以如下形式,指定`SST-2`数据集自动加载\n",
  285. "\n",
  286. "&emsp; 首次下载后会保存至`~/.cache/huggingface/modules/datasets_modules/datasets/glue/`目录下"
  287. ]
  288. },
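{
"cell_type": "markdown",
"id": "added-cnntext-sketch-md",
"metadata": {},
"source": [
"&emsp; The next cell is only an illustrative sketch in plain `pytorch` (not fastNLP's own implementation), mirroring the printed structure above\n",
"\n",
"&emsp; &emsp; it makes explicit how the three convolutions yield the `120 = 30 + 40 + 50` features fed into the final linear layer"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-cnntext-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"# plain-pytorch mirror of the CNNText structure printed above\n",
"embed = nn.Embedding(5194, 100)\n",
"convs = nn.ModuleList([\n",
"    nn.Conv1d(100, 30, kernel_size=1, bias=False),\n",
"    nn.Conv1d(100, 40, kernel_size=3, padding=1, bias=False),\n",
"    nn.Conv1d(100, 50, kernel_size=5, padding=2, bias=False),\n",
"])\n",
"dropout = nn.Dropout(p=0.1)\n",
"fc = nn.Linear(120, 2)\n",
"\n",
"words = torch.randint(0, 5194, (4, 20))              # made-up batch: 4 sentences, 20 token ids each\n",
"x = embed(words).transpose(1, 2)                     # Conv1d expects (batch, channels, seq_len)\n",
"pooled = [conv(x).max(dim=-1)[0] for conv in convs]  # max over time, as in ConvMaxpool\n",
"logits = fc(dropout(torch.cat(pooled, dim=1)))       # (4, 30 + 40 + 50) -> (4, 2)\n",
"print(logits.shape)                                  # torch.Size([4, 2])"
]
},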
  289. {
  290. "cell_type": "code",
  291. "execution_count": 8,
  292. "id": "1aa5cf6d",
  293. "metadata": {},
  294. "outputs": [
  295. {
  296. "name": "stderr",
  297. "output_type": "stream",
  298. "text": [
  299. "Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.\n",
  300. "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
  301. ]
  302. },
  303. {
  304. "data": {
  305. "application/vnd.jupyter.widget-view+json": {
  306. "model_id": "70cde65067c64fdba1d5e798e2b8d631",
  307. "version_major": 2,
  308. "version_minor": 0
  309. },
  310. "text/plain": [
  311. " 0%| | 0/3 [00:00<?, ?it/s]"
  312. ]
  313. },
  314. "metadata": {},
  315. "output_type": "display_data"
  316. }
  317. ],
  318. "source": [
  319. "from datasets import load_dataset\n",
  320. "\n",
  321. "sst2data = load_dataset('glue', 'sst2')"
  322. ]
  323. },
  324. {
  325. "cell_type": "markdown",
  326. "id": "c476abe7",
  327. "metadata": {},
  328. "source": [
  329. "紧接着,使用`tutorial-1`和`tutorial-2`中的知识,将数据集转化为`fastNLP`中的`DataSet`格式\n",
  330. "\n",
  331. "&emsp; **使用`apply_more`函数、`Vocabulary`模块的`from_/index_dataset`函数预处理数据**\n",
  332. "\n",
  333. "&emsp; &emsp; 并结合`delete_field`函数删除字段调整格式,`split`函数划分测试集和验证集\n",
  334. "\n",
  335. "&emsp; **仅保留`'words'`字段表示输入文本单词序号序列、`'target'`字段表示文本对应预测输出结果**\n",
  336. "\n",
  337. "&emsp; &emsp; 两者**对应到`CNNText`中`train_step`函数和`evaluate_step`函数的签名/输入参数**"
  338. ]
  339. },
  340. {
  341. "cell_type": "code",
  342. "execution_count": 9,
  343. "id": "357ea748",
  344. "metadata": {},
  345. "outputs": [
  346. {
  347. "data": {
  348. "text/html": [
  349. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  350. "</pre>\n"
  351. ],
  352. "text/plain": [
  353. "\n"
  354. ]
  355. },
  356. "metadata": {},
  357. "output_type": "display_data"
  358. },
  359. {
  360. "data": {
  361. "application/vnd.jupyter.widget-view+json": {
  362. "model_id": "",
  363. "version_major": 2,
  364. "version_minor": 0
  365. },
  366. "text/plain": [
  367. "Processing: 0%| | 0/6000 [00:00<?, ?it/s]"
  368. ]
  369. },
  370. "metadata": {},
  371. "output_type": "display_data"
  372. }
  373. ],
  374. "source": [
  375. "import sys\n",
  376. "sys.path.append('..')\n",
  377. "\n",
  378. "from fastNLP import DataSet\n",
  379. "\n",
  380. "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n",
  381. "\n",
  382. "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n",
  383. " progress_bar=\"tqdm\")\n",
  384. "dataset.delete_field('sentence')\n",
  385. "dataset.delete_field('label')\n",
  386. "dataset.delete_field('idx')\n",
  387. "\n",
  388. "from fastNLP import Vocabulary\n",
  389. "\n",
  390. "vocab = Vocabulary()\n",
  391. "vocab.from_dataset(dataset, field_name='words')\n",
  392. "vocab.index_dataset(dataset, field_name='words')\n",
  393. "\n",
  394. "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)"
  395. ]
  396. },
  397. {
  398. "cell_type": "markdown",
  399. "id": "96380c67",
  400. "metadata": {},
  401. "source": [
  402. "然后,使用`tutorial-3`中的知识,**通过`prepare_torch_dataloader`处理数据集得到`dataloader`**"
  403. ]
  404. },
  405. {
  406. "cell_type": "code",
  407. "execution_count": 10,
  408. "id": "b9dd1273",
  409. "metadata": {},
  410. "outputs": [],
  411. "source": [
  412. "from fastNLP import prepare_torch_dataloader\n",
  413. "\n",
  414. "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
  415. "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
  416. ]
  417. },
  418. {
  419. "cell_type": "markdown",
  420. "id": "96941b63",
  421. "metadata": {},
  422. "source": [
  423. "接着,**从`fastNLP.models.torch`路径下导入`CNNText`**,初始化`CNNText`实例以及`optimizer`实例\n",
  424. "\n",
  425. "&emsp; 注意:初始化`CNNText`时,**二元组参数`embed`、分类数量`num_classes`是必须传入的**,其中\n",
  426. "\n",
  427. "&emsp; &emsp; **`embed`表示嵌入层的嵌入抽取矩阵大小**,因此第二个元素对应的是默认隐藏层维度 `100`维"
  428. ]
  429. },
  430. {
  431. "cell_type": "code",
  432. "execution_count": 11,
  433. "id": "f6e76e2e",
  434. "metadata": {},
  435. "outputs": [],
  436. "source": [
  437. "from fastNLP.models.torch import CNNText\n",
  438. "\n",
  439. "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n",
  440. "\n",
  441. "from torch.optim import AdamW\n",
  442. "\n",
  443. "optimizers = AdamW(params=model.parameters(), lr=5e-4)"
  444. ]
  445. },
  446. {
  447. "cell_type": "markdown",
  448. "id": "0cc5ca10",
  449. "metadata": {},
  450. "source": [
  451. "最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练"
  452. ]
  453. },
  454. {
  455. "cell_type": "code",
  456. "execution_count": 12,
  457. "id": "50a13ee5",
  458. "metadata": {},
  459. "outputs": [],
  460. "source": [
  461. "from fastNLP import Trainer, Accuracy\n",
  462. "\n",
  463. "trainer = Trainer(\n",
  464. " model=model,\n",
  465. " driver='torch',\n",
  466. " device=0, # 'cuda'\n",
  467. " n_epochs=10,\n",
  468. " optimizers=optimizers,\n",
  469. " train_dataloader=train_dataloader,\n",
  470. " evaluate_dataloaders=evaluate_dataloader,\n",
  471. " metrics={'acc': Accuracy()}\n",
  472. ")"
  473. ]
  474. },
  475. {
  476. "cell_type": "code",
  477. "execution_count": 13,
  478. "id": "28903a7d",
  479. "metadata": {},
  480. "outputs": [
  481. {
  482. "data": {
  483. "text/html": [
  484. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[17:45:59] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches. <a href=\"file://../fastNLP/core/controllers/trainer.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/controllers/trainer.py#592\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">592</span></a>\n",
  485. "</pre>\n"
  486. ],
  487. "text/plain": [
  488. "\u001b[2;36m[17:45:59]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=147745;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=708408;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n"
  489. ]
  490. },
  491. "metadata": {},
  492. "output_type": "display_data"
  493. },
  494. {
  495. "data": {
  496. "application/vnd.jupyter.widget-view+json": {
  497. "model_id": "",
  498. "version_major": 2,
  499. "version_minor": 0
  500. },
  501. "text/plain": [
  502. "Output()"
  503. ]
  504. },
  505. "metadata": {},
  506. "output_type": "display_data"
  507. },
  508. {
  509. "data": {
  510. "text/html": [
  511. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  512. ],
  513. "text/plain": []
  514. },
  515. "metadata": {},
  516. "output_type": "display_data"
  517. },
  518. {
  519. "data": {
  520. "application/vnd.jupyter.widget-view+json": {
  521. "model_id": "",
  522. "version_major": 2,
  523. "version_minor": 0
  524. },
  525. "text/plain": [
  526. "Output()"
  527. ]
  528. },
  529. "metadata": {},
  530. "output_type": "display_data"
  531. },
  532. {
  533. "data": {
  534. "text/html": [
  535. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  536. "</pre>\n"
  537. ],
  538. "text/plain": [
  539. "\n"
  540. ]
  541. },
  542. "metadata": {},
  543. "output_type": "display_data"
  544. },
  545. {
  546. "data": {
  547. "text/html": [
  548. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  549. "</pre>\n"
  550. ],
  551. "text/plain": [
  552. "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  553. ]
  554. },
  555. "metadata": {},
  556. "output_type": "display_data"
  557. },
  558. {
  559. "data": {
  560. "text/html": [
  561. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  562. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.575</span>,\n",
  563. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160.0</span>,\n",
  564. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">92.0</span>\n",
  565. "<span style=\"font-weight: bold\">}</span>\n",
  566. "</pre>\n"
  567. ],
  568. "text/plain": [
  569. "\u001b[1m{\u001b[0m\n",
  570. " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.575\u001b[0m,\n",
  571. " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
  572. " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m92.0\u001b[0m\n",
  573. "\u001b[1m}\u001b[0m\n"
  574. ]
  575. },
  576. "metadata": {},
  577. "output_type": "display_data"
  578. },
  579. {
  580. "data": {
  581. "text/html": [
  582. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  583. "</pre>\n"
  584. ],
  585. "text/plain": [
  586. "\n"
  587. ]
  588. },
  589. "metadata": {},
  590. "output_type": "display_data"
  591. },
  592. {
  593. "data": {
  594. "text/html": [
  595. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  596. "</pre>\n"
  597. ],
  598. "text/plain": [
  599. "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  600. ]
  601. },
  602. "metadata": {},
  603. "output_type": "display_data"
  604. },
  605. {
  606. "data": {
  607. "text/html": [
  608. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  609. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.75625</span>,\n",
  610. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160.0</span>,\n",
  611. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">121.0</span>\n",
  612. "<span style=\"font-weight: bold\">}</span>\n",
  613. "</pre>\n"
  614. ],
  615. "text/plain": [
  616. "\u001b[1m{\u001b[0m\n",
  617. " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n",
  618. " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
  619. " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121.0\u001b[0m\n",
  620. "\u001b[1m}\u001b[0m\n"
  621. ]
  622. },
  623. "metadata": {},
  624. "output_type": "display_data"
  625. },
  626. {
  627. "data": {
  628. "text/html": [
  629. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  630. "</pre>\n"
  631. ],
  632. "text/plain": [
  633. "\n"
  634. ]
  635. },
  636. "metadata": {},
  637. "output_type": "display_data"
  638. },
  639. {
  640. "data": {
  641. "text/html": [
  642. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  643. "</pre>\n"
  644. ],
  645. "text/plain": [
  646. "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  647. ]
  648. },
  649. "metadata": {},
  650. "output_type": "display_data"
  651. },
  652. {
  653. "data": {
  654. "text/html": [
  655. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  656. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.78125</span>,\n",
  657. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160.0</span>,\n",
  658. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">125.0</span>\n",
  659. "<span style=\"font-weight: bold\">}</span>\n",
  660. "</pre>\n"
  661. ],
  662. "text/plain": [
  663. "\u001b[1m{\u001b[0m\n",
  664. " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.78125\u001b[0m,\n",
  665. " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
  666. " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m125.0\u001b[0m\n",
  667. "\u001b[1m}\u001b[0m\n"
  668. ]
  669. },
  670. "metadata": {},
  671. "output_type": "display_data"
  672. },
  673. {
  674. "data": {
  675. "text/html": [
  676. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  677. "</pre>\n"
  678. ],
  679. "text/plain": [
  680. "\n"
  681. ]
  682. },
  683. "metadata": {},
  684. "output_type": "display_data"
  685. },
  686. {
  687. "data": {
  688. "text/html": [
  689. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  690. "</pre>\n"
  691. ],
  692. "text/plain": [
  693. "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  694. ]
  695. },
  696. "metadata": {},
  697. "output_type": "display_data"
  698. },
  699. {
  700. "data": {
  701. "text/html": [
  702. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  703. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8</span>,\n",
  704. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160.0</span>,\n",
  705. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">128.0</span>\n",
  706. "<span style=\"font-weight: bold\">}</span>\n",
  707. "</pre>\n"
  708. ],
  709. "text/plain": [
  710. "\u001b[1m{\u001b[0m\n",
  711. " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n",
  712. " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
  713. " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n",
  714. "\u001b[1m}\u001b[0m\n"
  715. ]
  716. },
  717. "metadata": {},
  718. "output_type": "display_data"
  719. },
  720. {
  721. "data": {
  722. "text/html": [
  723. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  724. "</pre>\n"
  725. ],
  726. "text/plain": [
  727. "\n"
  728. ]
  729. },
  730. "metadata": {},
  731. "output_type": "display_data"
  732. },
  733. {
  734. "data": {
  735. "text/html": [
  736. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  737. "</pre>\n"
  738. ],
  739. "text/plain": [
  740. "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  741. ]
  742. },
  743. "metadata": {},
  744. "output_type": "display_data"
  745. },
  746. {
  747. "data": {
  748. "text/html": [
  749. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  750. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.79375</span>,\n",
  751. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160.0</span>,\n",
  752. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">127.0</span>\n",
  753. "<span style=\"font-weight: bold\">}</span>\n",
  754. "</pre>\n"
  755. ],
  756. "text/plain": [
  757. "\u001b[1m{\u001b[0m\n",
  758. " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n",
  759. " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
  760. " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n",
  761. "\u001b[1m}\u001b[0m\n"
  762. ]
  763. },
  764. "metadata": {},
  765. "output_type": "display_data"
  766. },
  767. {
  768. "data": {
  769. "text/html": [
  770. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  771. "</pre>\n"
  772. ],
  773. "text/plain": [
  774. "\n"
  775. ]
  776. },
  777. "metadata": {},
  778. "output_type": "display_data"
  779. },
  780. {
  781. "data": {
  782. "text/html": [
  783. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  784. "</pre>\n"
  785. ],
  786. "text/plain": [
  787. "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  788. ]
  789. },
  790. "metadata": {},
  791. "output_type": "display_data"
  792. },
  793. {
  794. "data": {
  795. "text/html": [
  796. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  797. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>,\n",
  798. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160.0</span>,\n",
  799. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">129.0</span>\n",
  800. "<span style=\"font-weight: bold\">}</span>\n",
  801. "</pre>\n"
  802. ],
  803. "text/plain": [
  804. "\u001b[1m{\u001b[0m\n",
  805. " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m,\n",
  806. " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
  807. " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m129.0\u001b[0m\n",
  808. "\u001b[1m}\u001b[0m\n"
  809. ]
  810. },
  811. "metadata": {},
  812. "output_type": "display_data"
  813. },
  814. {
  815. "data": {
  816. "text/html": [
  817. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  818. "</pre>\n"
  819. ],
  820. "text/plain": [
  821. "\n"
  822. ]
  823. },
  824. "metadata": {},
  825. "output_type": "display_data"
  826. },
  827. {
  828. "data": {
  829. "text/html": [
  830. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  831. "</pre>\n"
  832. ],
  833. "text/plain": [
  834. "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  835. ]
  836. },
  837. "metadata": {},
  838. "output_type": "display_data"
  839. },
  840. {
  841. "data": {
  842. "text/html": [
  843. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  844. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.81875</span>,\n",
  845. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160.0</span>,\n",
  846. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">131.0</span>\n",
  847. "<span style=\"font-weight: bold\">}</span>\n",
  848. "</pre>\n"
  849. ],
  850. "text/plain": [
  851. "\u001b[1m{\u001b[0m\n",
  852. " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n",
  853. " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
  854. " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n",
  855. "\u001b[1m}\u001b[0m\n"
  856. ]
  857. },
  858. "metadata": {},
  859. "output_type": "display_data"
  860. },
  861. {
  862. "data": {
  863. "text/html": [
  864. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  865. "</pre>\n"
  866. ],
  867. "text/plain": [
  868. "\n"
  869. ]
  870. },
  871. "metadata": {},
  872. "output_type": "display_data"
  873. },
  874. {
  875. "data": {
  876. "text/html": [
  877. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  878. "</pre>\n"
  879. ],
  880. "text/plain": [
  881. "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  882. ]
  883. },
  884. "metadata": {},
  885. "output_type": "display_data"
  886. },
  887. {
  888. "data": {
  889. "text/html": [
  890. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  891. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.825</span>,\n",
  892. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160.0</span>,\n",
  893. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">132.0</span>\n",
  894. "<span style=\"font-weight: bold\">}</span>\n",
  895. "</pre>\n"
  896. ],
  897. "text/plain": [
  898. "\u001b[1m{\u001b[0m\n",
  899. " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.825\u001b[0m,\n",
  900. " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
  901. " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m132.0\u001b[0m\n",
  902. "\u001b[1m}\u001b[0m\n"
  903. ]
  904. },
  905. "metadata": {},
  906. "output_type": "display_data"
  907. },
  908. {
  909. "data": {
  910. "text/html": [
  911. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  912. "</pre>\n"
  913. ],
  914. "text/plain": [
  915. "\n"
  916. ]
  917. },
  918. "metadata": {},
  919. "output_type": "display_data"
  920. },
  921. {
  922. "data": {
  923. "text/html": [
  924. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">9</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  925. "</pre>\n"
  926. ],
  927. "text/plain": [
  928. "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  929. ]
  930. },
  931. "metadata": {},
  932. "output_type": "display_data"
  933. },
  934. {
  935. "data": {
  936. "text/html": [
  937. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  938. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.81875</span>,\n",
  939. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160.0</span>,\n",
  940. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">131.0</span>\n",
  941. "<span style=\"font-weight: bold\">}</span>\n",
  942. "</pre>\n"
  943. ],
  944. "text/plain": [
  945. "\u001b[1m{\u001b[0m\n",
  946. " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n",
  947. " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
  948. " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n",
  949. "\u001b[1m}\u001b[0m\n"
  950. ]
  951. },
  952. "metadata": {},
  953. "output_type": "display_data"
  954. },
  955. {
  956. "data": {
  957. "text/html": [
  958. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  959. "</pre>\n"
  960. ],
  961. "text/plain": [
  962. "\n"
  963. ]
  964. },
  965. "metadata": {},
  966. "output_type": "display_data"
  967. },
  968. {
  969. "data": {
  970. "text/html": [
  971. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  972. "</pre>\n"
  973. ],
  974. "text/plain": [
  975. "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  976. ]
  977. },
  978. "metadata": {},
  979. "output_type": "display_data"
  980. },
  981. {
  982. "data": {
  983. "text/html": [
  984. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  985. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.81875</span>,\n",
  986. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160.0</span>,\n",
  987. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">131.0</span>\n",
  988. "<span style=\"font-weight: bold\">}</span>\n",
  989. "</pre>\n"
  990. ],
  991. "text/plain": [
  992. "\u001b[1m{\u001b[0m\n",
  993. " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n",
  994. " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
  995. " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n",
  996. "\u001b[1m}\u001b[0m\n"
  997. ]
  998. },
  999. "metadata": {},
  1000. "output_type": "display_data"
  1001. },
  1002. {
  1003. "data": {
  1004. "text/html": [
  1005. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  1006. ],
  1007. "text/plain": []
  1008. },
  1009. "metadata": {},
  1010. "output_type": "display_data"
  1011. },
  1012. {
  1013. "data": {
  1014. "text/html": [
  1015. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1016. "</pre>\n"
  1017. ],
  1018. "text/plain": [
  1019. "\n"
  1020. ]
  1021. },
  1022. "metadata": {},
  1023. "output_type": "display_data"
  1024. }
  1025. ],
  1026. "source": [
  1027. "trainer.run(num_eval_batch_per_dl=10)"
  1028. ]
  1029. },
  1030. {
  1031. "cell_type": "code",
  1032. "execution_count": 14,
  1033. "id": "f47a6a35",
  1034. "metadata": {},
  1035. "outputs": [
  1036. {
  1037. "data": {
  1038. "application/vnd.jupyter.widget-view+json": {
  1039. "model_id": "",
  1040. "version_major": 2,
  1041. "version_minor": 0
  1042. },
  1043. "text/plain": [
  1044. "Output()"
  1045. ]
  1046. },
  1047. "metadata": {},
  1048. "output_type": "display_data"
  1049. },
  1050. {
  1051. "data": {
  1052. "text/html": [
  1053. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  1054. ],
  1055. "text/plain": []
  1056. },
  1057. "metadata": {},
  1058. "output_type": "display_data"
  1059. },
  1060. {
  1061. "data": {
  1062. "text/plain": [
  1063. "{'acc#acc': 0.79, 'total#acc': 900.0, 'correct#acc': 711.0}"
  1064. ]
  1065. },
  1066. "execution_count": 14,
  1067. "metadata": {},
  1068. "output_type": "execute_result"
  1069. }
  1070. ],
  1071. "source": [
  1072. "trainer.evaluator.run()"
  1073. ]
  1074. },
  1075. {
  1076. "cell_type": "markdown",
  1077. "id": "7c811257",
  1078. "metadata": {},
  1079. "source": [
  1080. "&emsp; 注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间"
  1081. ]
  1082. },
  1083. {
  1084. "cell_type": "code",
  1085. "execution_count": 15,
  1086. "id": "c1a2e2ca",
  1087. "metadata": {},
  1088. "outputs": [
  1089. {
  1090. "data": {
  1091. "text/plain": [
  1092. "342"
  1093. ]
  1094. },
  1095. "execution_count": 15,
  1096. "metadata": {},
  1097. "output_type": "execute_result"
  1098. }
  1099. ],
  1100. "source": [
  1101. "import gc\n",
  1102. "\n",
  1103. "del model\n",
  1104. "del trainer\n",
  1105. "del dataset\n",
  1106. "del sst2data\n",
  1107. "\n",
  1108. "gc.collect()"
  1109. ]
  1110. },
  1111. {
  1112. "cell_type": "markdown",
  1113. "id": "6aec2a19",
  1114. "metadata": {},
  1115. "source": [
  1116. "### 2.2 示例二:models 实现 BiLSTM 标注\n",
  1117. "\n",
  1118. "&emsp; 通过两个示例一的对比可以发现,得益于`models`对模型结构的封装,使用`models`明显更加便捷\n",
  1119. "\n",
  1120. "&emsp; &emsp; 针对更加复杂的模型时,编码更加轻松;本示例将使用`models`中的`BiLSTMCRF`模型\n",
  1121. "\n",
  1122. "&emsp; 避免`CRF`和`Viterbi`算法代码书写的困难,轻松实现`CoNLL-2003`中的命名实体识别`NER`任务\n",
  1123. "\n",
  1124. "模型使用方面,如上所述,这里使用**基于双向`LSTM`+条件随机场`CRF`的标注模型`BiLSTMCRF`**,结构如下所示\n",
  1125. "\n",
  1126. "&emsp; 其中,隐藏层维度默认`100`维,因此对应双向`LSTM`输出`200`维,`dropout`层退学概率、`LSTM`层数可调\n",
  1127. "\n",
  1128. "```\n",
  1129. "BiLSTMCRF(\n",
  1130. " (embed): Embedding(7590, 100)\n",
  1131. " (lstm): LSTM(\n",
  1132. " (lstm): LSTM(100, 100, batch_first=True, bidirectional=True)\n",
  1133. " )\n",
  1134. " (dropout): Dropout(p=0.1, inplace=False)\n",
  1135. " (fc): Linear(in_features=200, out_features=9, bias=True)\n",
  1136. " (crf): ConditionalRandomField()\n",
  1137. ")\n",
  1138. "```\n",
  1139. "\n",
  1140. "数据使用方面,此处仍然**使用`datasets`模块中的`load_dataset`函数**,以如下形式,加载`CoNLL-2003`数据集\n",
  1141. "\n",
  1142. "&emsp; 首次下载后会保存至`~.cache/huggingface/datasets/conll2003/conll2003/1.0.0/`目录下"
  1143. ]
  1144. },
  1145. {
  1146. "cell_type": "code",
  1147. "execution_count": 16,
  1148. "id": "03e66686",
  1149. "metadata": {},
  1150. "outputs": [
  1151. {
  1152. "name": "stderr",
  1153. "output_type": "stream",
  1154. "text": [
  1155. "Reusing dataset conll2003 (/remote-home/xrliu/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)\n"
  1156. ]
  1157. },
  1158. {
  1159. "data": {
  1160. "application/vnd.jupyter.widget-view+json": {
  1161. "model_id": "3ec9e0ce9a054339a2453420c2c9f28b",
  1162. "version_major": 2,
  1163. "version_minor": 0
  1164. },
  1165. "text/plain": [
  1166. " 0%| | 0/3 [00:00<?, ?it/s]"
  1167. ]
  1168. },
  1169. "metadata": {},
  1170. "output_type": "display_data"
  1171. }
  1172. ],
  1173. "source": [
  1174. "from datasets import load_dataset\n",
  1175. "\n",
  1176. "ner2data = load_dataset('conll2003', 'conll2003')"
  1177. ]
  1178. },
  1179. {
  1180. "cell_type": "markdown",
  1181. "id": "fc505631",
  1182. "metadata": {},
  1183. "source": [
  1184. "紧接着,使用`tutorial-1`和`tutorial-2`中的知识,将数据集转化为`fastNLP`中的`DataSet`格式\n",
  1185. "\n",
  1186. "&emsp; 完成数据集格式调整、文本序列化等操作;此处**需要`'words'`、`'seq_len'`、`'target'`三个字段**\n",
  1187. "\n",
  1188. "此外,**需要定义`NER`标签到标签序号的映射**(**词汇表`label_vocab`**),数据集中标签已经完成了序号映射\n",
  1189. "\n",
  1190. "&emsp; 所以需要人工定义**`9`个标签对应之前的`9`个分类目标**;数据集说明中规定,`'O'`表示其他标签\n",
  1191. "\n",
  1192. "&emsp; **后缀`'-PER'`、`'-ORG'`、`'-LOC'`、`'-MISC'`对应人名、组织名、地名、时间等其他命名**\n",
  1193. "\n",
  1194. "&emsp; **前缀`'B-'`表示起始标签、`'I-'`表示终止标签**;例如,`'B-PER'`表示人名实体的起始标签"
  1195. ]
  1196. },
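{
"cell_type": "markdown",
"id": "added-bio-sketch-md",
"metadata": {},
"source": [
"&emsp; For intuition, the next cell is a tiny sketch with a made-up sentence (not taken from the dataset) showing how the `9` label indices map to `BIO` tags"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-bio-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"# the 9 NER tags, in the same order as the label_vocab defined in the next cell\n",
"ner_labels = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']\n",
"\n",
"tokens  = ['George', 'Washington', 'visited', 'New', 'York', '.']  # hypothetical example\n",
"tag_ids = [1, 2, 0, 5, 6, 0]                                       # 'B-' starts an entity, 'I-' continues it\n",
"print([ner_labels[i] for i in tag_ids])\n",
"# ['B-PER', 'I-PER', 'O', 'B-LOC', 'I-LOC', 'O']"
]
},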
  1197. {
  1198. "cell_type": "code",
  1199. "execution_count": 17,
  1200. "id": "1f88cad4",
  1201. "metadata": {},
  1202. "outputs": [
  1203. {
  1204. "data": {
  1205. "application/vnd.jupyter.widget-view+json": {
  1206. "model_id": "",
  1207. "version_major": 2,
  1208. "version_minor": 0
  1209. },
  1210. "text/plain": [
  1211. "Processing: 0%| | 0/4000 [00:00<?, ?it/s]"
  1212. ]
  1213. },
  1214. "metadata": {},
  1215. "output_type": "display_data"
  1216. }
  1217. ],
  1218. "source": [
  1219. "import sys\n",
  1220. "sys.path.append('..')\n",
  1221. "\n",
  1222. "from fastNLP import DataSet\n",
  1223. "\n",
  1224. "dataset = DataSet.from_pandas(ner2data['train'].to_pandas())[:4000]\n",
  1225. "\n",
  1226. "dataset.apply_more(lambda ins:{'words': ins['tokens'], 'seq_len': len(ins['tokens']), 'target': ins['ner_tags']}, \n",
  1227. " progress_bar=\"tqdm\")\n",
  1228. "dataset.delete_field('tokens')\n",
  1229. "dataset.delete_field('ner_tags')\n",
  1230. "dataset.delete_field('pos_tags')\n",
  1231. "dataset.delete_field('chunk_tags')\n",
  1232. "dataset.delete_field('id')\n",
  1233. "\n",
  1234. "from fastNLP import Vocabulary\n",
  1235. "\n",
  1236. "token_vocab = Vocabulary()\n",
  1237. "token_vocab.from_dataset(dataset, field_name='words')\n",
  1238. "token_vocab.index_dataset(dataset, field_name='words')\n",
  1239. "label_vocab = Vocabulary(padding=None, unknown=None)\n",
  1240. "label_vocab.add_word_lst(['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'])\n",
  1241. "\n",
  1242. "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)"
  1243. ]
  1244. },
  1245. {
  1246. "cell_type": "markdown",
  1247. "id": "d9889427",
  1248. "metadata": {},
  1249. "source": [
  1250. "然后,同样使用`tutorial-3`中的知识,通过`prepare_torch_dataloader`处理数据集得到`dataloader`"
  1251. ]
  1252. },
  1253. {
  1254. "cell_type": "code",
  1255. "execution_count": 18,
  1256. "id": "7802a072",
  1257. "metadata": {},
  1258. "outputs": [],
  1259. "source": [
  1260. "from fastNLP import prepare_torch_dataloader\n",
  1261. "\n",
  1262. "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
  1263. "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
  1264. ]
  1265. },
  1266. {
  1267. "cell_type": "markdown",
  1268. "id": "2bc7831b",
  1269. "metadata": {},
  1270. "source": [
  1271. "接着,**从`fastNLP.models.torch`路径下导入`BiLSTMCRF`**,初始化`BiLSTMCRF`实例和优化器\n",
  1272. "\n",
  1273. "&emsp; 注意:初始化`BiLSTMCRF`时,和`CNNText`相同,**参数`embed`、`num_classes`是必须传入的**\n",
  1274. "\n",
  1275. "&emsp; &emsp; 隐藏层维度`hidden_size`默认`100`维,调整`150`维;退学概率默认`0.1`,调整`0.2`"
  1276. ]
  1277. },
  1278. {
  1279. "cell_type": "code",
  1280. "execution_count": 19,
  1281. "id": "4e12c09f",
  1282. "metadata": {},
  1283. "outputs": [],
  1284. "source": [
  1285. "from fastNLP.models.torch import BiLSTMCRF\n",
  1286. "\n",
  1287. "model = BiLSTMCRF(embed=(len(token_vocab), 150), num_classes=len(label_vocab), \n",
  1288. " num_layers=1, hidden_size=150, dropout=0.2)\n",
  1289. "\n",
  1290. "from torch.optim import AdamW\n",
  1291. "\n",
  1292. "optimizers = AdamW(params=model.parameters(), lr=1e-3)"
  1293. ]
  1294. },
  1295. {
  1296. "cell_type": "markdown",
  1297. "id": "bf30608f",
  1298. "metadata": {},
  1299. "source": [
  1300. "最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练\n",
  1301. "\n",
  1302. "&emsp; 参考`tutorial-4`中的内容,**使用`SpanFPreRecMetric`作为`NER`的评价标准**\n",
  1303. "\n",
  1304. "&emsp; 同时,**初始化时需要添加`vocabulary`形式的标签与序号之间的映射`tag_vocab`**"
  1305. ]
  1306. },
  1307. {
  1308. "cell_type": "code",
  1309. "execution_count": 20,
  1310. "id": "cbd6c205",
  1311. "metadata": {},
  1312. "outputs": [],
  1313. "source": [
  1314. "from fastNLP import Trainer, SpanFPreRecMetric\n",
  1315. "\n",
  1316. "trainer = Trainer(\n",
  1317. " model=model,\n",
  1318. " driver='torch',\n",
  1319. " device=0, # 'cuda'\n",
  1320. " n_epochs=10,\n",
  1321. " optimizers=optimizers,\n",
  1322. " train_dataloader=train_dataloader,\n",
  1323. " evaluate_dataloaders=evaluate_dataloader,\n",
  1324. " metrics={'F1': SpanFPreRecMetric(tag_vocab=label_vocab)}\n",
  1325. ")"
  1326. ]
  1327. },
  1328. {
  1329. "cell_type": "code",
  1330. "execution_count": 21,
  1331. "id": "0f8eff34",
  1332. "metadata": {},
  1333. "outputs": [
  1334. {
  1335. "data": {
  1336. "text/html": [
  1337. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[17:49:16] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches. <a href=\"file://../fastNLP/core/controllers/trainer.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/controllers/trainer.py#592\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">592</span></a>\n",
  1338. "</pre>\n"
  1339. ],
  1340. "text/plain": [
  1341. "\u001b[2;36m[17:49:16]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=766109;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=787419;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n"
  1342. ]
  1343. },
  1344. "metadata": {},
  1345. "output_type": "display_data"
  1346. },
  1347. {
  1348. "data": {
  1349. "application/vnd.jupyter.widget-view+json": {
  1350. "model_id": "",
  1351. "version_major": 2,
  1352. "version_minor": 0
  1353. },
  1354. "text/plain": [
  1355. "Output()"
  1356. ]
  1357. },
  1358. "metadata": {},
  1359. "output_type": "display_data"
  1360. },
  1361. {
  1362. "data": {
  1363. "text/html": [
  1364. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  1365. ],
  1366. "text/plain": []
  1367. },
  1368. "metadata": {},
  1369. "output_type": "display_data"
  1370. },
  1371. {
  1372. "data": {
  1373. "application/vnd.jupyter.widget-view+json": {
  1374. "model_id": "",
  1375. "version_major": 2,
  1376. "version_minor": 0
  1377. },
  1378. "text/plain": [
  1379. "Output()"
  1380. ]
  1381. },
  1382. "metadata": {},
  1383. "output_type": "display_data"
  1384. },
  1385. {
  1386. "data": {
  1387. "text/html": [
  1388. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1389. "</pre>\n"
  1390. ],
  1391. "text/plain": [
  1392. "\n"
  1393. ]
  1394. },
  1395. "metadata": {},
  1396. "output_type": "display_data"
  1397. },
  1398. {
  1399. "data": {
  1400. "text/html": [
  1401. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  1402. "</pre>\n"
  1403. ],
  1404. "text/plain": [
  1405. "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  1406. ]
  1407. },
  1408. "metadata": {},
  1409. "output_type": "display_data"
  1410. },
  1411. {
  1412. "data": {
  1413. "text/html": [
  1414. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1415. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.220374</span>,\n",
  1416. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.25</span>,\n",
  1417. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.197026</span>\n",
  1418. "<span style=\"font-weight: bold\">}</span>\n",
  1419. "</pre>\n"
  1420. ],
  1421. "text/plain": [
  1422. "\u001b[1m{\u001b[0m\n",
  1423. " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.220374\u001b[0m,\n",
  1424. " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.25\u001b[0m,\n",
  1425. " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.197026\u001b[0m\n",
  1426. "\u001b[1m}\u001b[0m\n"
  1427. ]
  1428. },
  1429. "metadata": {},
  1430. "output_type": "display_data"
  1431. },
  1432. {
  1433. "data": {
  1434. "text/html": [
  1435. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1436. "</pre>\n"
  1437. ],
  1438. "text/plain": [
  1439. "\n"
  1440. ]
  1441. },
  1442. "metadata": {},
  1443. "output_type": "display_data"
  1444. },
  1445. {
  1446. "data": {
  1447. "text/html": [
  1448. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  1449. "</pre>\n"
  1450. ],
  1451. "text/plain": [
  1452. "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  1453. ]
  1454. },
  1455. "metadata": {},
  1456. "output_type": "display_data"
  1457. },
  1458. {
  1459. "data": {
  1460. "text/html": [
  1461. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1462. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.442857</span>,\n",
  1463. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.426117</span>,\n",
  1464. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.460967</span>\n",
  1465. "<span style=\"font-weight: bold\">}</span>\n",
  1466. "</pre>\n"
  1467. ],
  1468. "text/plain": [
  1469. "\u001b[1m{\u001b[0m\n",
  1470. " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.442857\u001b[0m,\n",
  1471. " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.426117\u001b[0m,\n",
  1472. " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.460967\u001b[0m\n",
  1473. "\u001b[1m}\u001b[0m\n"
  1474. ]
  1475. },
  1476. "metadata": {},
  1477. "output_type": "display_data"
  1478. },
  1479. {
  1480. "data": {
  1481. "text/html": [
  1482. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1483. "</pre>\n"
  1484. ],
  1485. "text/plain": [
  1486. "\n"
  1487. ]
  1488. },
  1489. "metadata": {},
  1490. "output_type": "display_data"
  1491. },
  1492. {
  1493. "data": {
  1494. "text/html": [
  1495. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  1496. "</pre>\n"
  1497. ],
  1498. "text/plain": [
  1499. "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  1500. ]
  1501. },
  1502. "metadata": {},
  1503. "output_type": "display_data"
  1504. },
  1505. {
  1506. "data": {
  1507. "text/html": [
  1508. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1509. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.572954</span>,\n",
  1510. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.549488</span>,\n",
  1511. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.598513</span>\n",
  1512. "<span style=\"font-weight: bold\">}</span>\n",
  1513. "</pre>\n"
  1514. ],
  1515. "text/plain": [
  1516. "\u001b[1m{\u001b[0m\n",
  1517. " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.572954\u001b[0m,\n",
  1518. " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.549488\u001b[0m,\n",
  1519. " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.598513\u001b[0m\n",
  1520. "\u001b[1m}\u001b[0m\n"
  1521. ]
  1522. },
  1523. "metadata": {},
  1524. "output_type": "display_data"
  1525. },
  1526. {
  1527. "data": {
  1528. "text/html": [
  1529. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1530. "</pre>\n"
  1531. ],
  1532. "text/plain": [
  1533. "\n"
  1534. ]
  1535. },
  1536. "metadata": {},
  1537. "output_type": "display_data"
  1538. },
  1539. {
  1540. "data": {
  1541. "text/html": [
  1542. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  1543. "</pre>\n"
  1544. ],
  1545. "text/plain": [
  1546. "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  1547. ]
  1548. },
  1549. "metadata": {},
  1550. "output_type": "display_data"
  1551. },
  1552. {
  1553. "data": {
  1554. "text/html": [
  1555. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1556. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.665399</span>,\n",
  1557. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.680934</span>,\n",
  1558. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.650558</span>\n",
  1559. "<span style=\"font-weight: bold\">}</span>\n",
  1560. "</pre>\n"
  1561. ],
  1562. "text/plain": [
  1563. "\u001b[1m{\u001b[0m\n",
  1564. " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.665399\u001b[0m,\n",
  1565. " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.680934\u001b[0m,\n",
  1566. " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.650558\u001b[0m\n",
  1567. "\u001b[1m}\u001b[0m\n"
  1568. ]
  1569. },
  1570. "metadata": {},
  1571. "output_type": "display_data"
  1572. },
  1573. {
  1574. "data": {
  1575. "text/html": [
  1576. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1577. "</pre>\n"
  1578. ],
  1579. "text/plain": [
  1580. "\n"
  1581. ]
  1582. },
  1583. "metadata": {},
  1584. "output_type": "display_data"
  1585. },
  1586. {
  1587. "data": {
  1588. "text/html": [
  1589. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  1590. "</pre>\n"
  1591. ],
  1592. "text/plain": [
  1593. "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  1594. ]
  1595. },
  1596. "metadata": {},
  1597. "output_type": "display_data"
  1598. },
  1599. {
  1600. "data": {
  1601. "text/html": [
  1602. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1603. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.734694</span>,\n",
  1604. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.733333</span>,\n",
  1605. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.736059</span>\n",
  1606. "<span style=\"font-weight: bold\">}</span>\n",
  1607. "</pre>\n"
  1608. ],
  1609. "text/plain": [
  1610. "\u001b[1m{\u001b[0m\n",
  1611. " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.734694\u001b[0m,\n",
  1612. " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.733333\u001b[0m,\n",
  1613. " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.736059\u001b[0m\n",
  1614. "\u001b[1m}\u001b[0m\n"
  1615. ]
  1616. },
  1617. "metadata": {},
  1618. "output_type": "display_data"
  1619. },
  1620. {
  1621. "data": {
  1622. "text/html": [
  1623. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1624. "</pre>\n"
  1625. ],
  1626. "text/plain": [
  1627. "\n"
  1628. ]
  1629. },
  1630. "metadata": {},
  1631. "output_type": "display_data"
  1632. },
  1633. {
  1634. "data": {
  1635. "text/html": [
  1636. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  1637. "</pre>\n"
  1638. ],
  1639. "text/plain": [
  1640. "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  1641. ]
  1642. },
  1643. "metadata": {},
  1644. "output_type": "display_data"
  1645. },
  1646. {
  1647. "data": {
  1648. "text/html": [
  1649. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1650. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.742647</span>,\n",
  1651. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.734545</span>,\n",
  1652. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.750929</span>\n",
  1653. "<span style=\"font-weight: bold\">}</span>\n",
  1654. "</pre>\n"
  1655. ],
  1656. "text/plain": [
  1657. "\u001b[1m{\u001b[0m\n",
  1658. " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.742647\u001b[0m,\n",
  1659. " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.734545\u001b[0m,\n",
  1660. " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.750929\u001b[0m\n",
  1661. "\u001b[1m}\u001b[0m\n"
  1662. ]
  1663. },
  1664. "metadata": {},
  1665. "output_type": "display_data"
  1666. },
  1667. {
  1668. "data": {
  1669. "text/html": [
  1670. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1671. "</pre>\n"
  1672. ],
  1673. "text/plain": [
  1674. "\n"
  1675. ]
  1676. },
  1677. "metadata": {},
  1678. "output_type": "display_data"
  1679. },
  1680. {
  1681. "data": {
  1682. "text/html": [
  1683. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  1684. "</pre>\n"
  1685. ],
  1686. "text/plain": [
  1687. "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  1688. ]
  1689. },
  1690. "metadata": {},
  1691. "output_type": "display_data"
  1692. },
  1693. {
  1694. "data": {
  1695. "text/html": [
  1696. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1697. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.773585</span>,\n",
  1698. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.785441</span>,\n",
  1699. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.762082</span>\n",
  1700. "<span style=\"font-weight: bold\">}</span>\n",
  1701. "</pre>\n"
  1702. ],
  1703. "text/plain": [
  1704. "\u001b[1m{\u001b[0m\n",
  1705. " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.773585\u001b[0m,\n",
  1706. " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.785441\u001b[0m,\n",
  1707. " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.762082\u001b[0m\n",
  1708. "\u001b[1m}\u001b[0m\n"
  1709. ]
  1710. },
  1711. "metadata": {},
  1712. "output_type": "display_data"
  1713. },
  1714. {
  1715. "data": {
  1716. "text/html": [
  1717. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1718. "</pre>\n"
  1719. ],
  1720. "text/plain": [
  1721. "\n"
  1722. ]
  1723. },
  1724. "metadata": {},
  1725. "output_type": "display_data"
  1726. },
  1727. {
  1728. "data": {
  1729. "text/html": [
  1730. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  1731. "</pre>\n"
  1732. ],
  1733. "text/plain": [
  1734. "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  1735. ]
  1736. },
  1737. "metadata": {},
  1738. "output_type": "display_data"
  1739. },
  1740. {
  1741. "data": {
  1742. "text/html": [
  1743. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1744. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.770115</span>,\n",
  1745. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.794466</span>,\n",
  1746. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.747212</span>\n",
  1747. "<span style=\"font-weight: bold\">}</span>\n",
  1748. "</pre>\n"
  1749. ],
  1750. "text/plain": [
  1751. "\u001b[1m{\u001b[0m\n",
  1752. " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.770115\u001b[0m,\n",
  1753. " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.794466\u001b[0m,\n",
  1754. " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.747212\u001b[0m\n",
  1755. "\u001b[1m}\u001b[0m\n"
  1756. ]
  1757. },
  1758. "metadata": {},
  1759. "output_type": "display_data"
  1760. },
  1761. {
  1762. "data": {
  1763. "text/html": [
  1764. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1765. "</pre>\n"
  1766. ],
  1767. "text/plain": [
  1768. "\n"
  1769. ]
  1770. },
  1771. "metadata": {},
  1772. "output_type": "display_data"
  1773. },
  1774. {
  1775. "data": {
  1776. "text/html": [
  1777. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">9</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  1778. "</pre>\n"
  1779. ],
  1780. "text/plain": [
  1781. "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  1782. ]
  1783. },
  1784. "metadata": {},
  1785. "output_type": "display_data"
  1786. },
  1787. {
  1788. "data": {
  1789. "text/html": [
  1790. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1791. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.7603</span>,\n",
  1792. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.766038</span>,\n",
  1793. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.754647</span>\n",
  1794. "<span style=\"font-weight: bold\">}</span>\n",
  1795. "</pre>\n"
  1796. ],
  1797. "text/plain": [
  1798. "\u001b[1m{\u001b[0m\n",
  1799. " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.7603\u001b[0m,\n",
  1800. " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.766038\u001b[0m,\n",
  1801. " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.754647\u001b[0m\n",
  1802. "\u001b[1m}\u001b[0m\n"
  1803. ]
  1804. },
  1805. "metadata": {},
  1806. "output_type": "display_data"
  1807. },
  1808. {
  1809. "data": {
  1810. "text/html": [
  1811. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1812. "</pre>\n"
  1813. ],
  1814. "text/plain": [
  1815. "\n"
  1816. ]
  1817. },
  1818. "metadata": {},
  1819. "output_type": "display_data"
  1820. },
  1821. {
  1822. "data": {
  1823. "text/html": [
  1824. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
  1825. "</pre>\n"
  1826. ],
  1827. "text/plain": [
  1828. "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
  1829. ]
  1830. },
  1831. "metadata": {},
  1832. "output_type": "display_data"
  1833. },
  1834. {
  1835. "data": {
  1836. "text/html": [
  1837. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
  1838. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.743682</span>,\n",
  1839. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.722807</span>,\n",
  1840. " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.765799</span>\n",
  1841. "<span style=\"font-weight: bold\">}</span>\n",
  1842. "</pre>\n"
  1843. ],
  1844. "text/plain": [
  1845. "\u001b[1m{\u001b[0m\n",
  1846. " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.743682\u001b[0m,\n",
  1847. " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.722807\u001b[0m,\n",
  1848. " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.765799\u001b[0m\n",
  1849. "\u001b[1m}\u001b[0m\n"
  1850. ]
  1851. },
  1852. "metadata": {},
  1853. "output_type": "display_data"
  1854. },
  1855. {
  1856. "data": {
  1857. "text/html": [
  1858. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  1859. ],
  1860. "text/plain": []
  1861. },
  1862. "metadata": {},
  1863. "output_type": "display_data"
  1864. },
  1865. {
  1866. "data": {
  1867. "text/html": [
  1868. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
  1869. "</pre>\n"
  1870. ],
  1871. "text/plain": [
  1872. "\n"
  1873. ]
  1874. },
  1875. "metadata": {},
  1876. "output_type": "display_data"
  1877. }
  1878. ],
  1879. "source": [
  1880. "trainer.run(num_eval_batch_per_dl=10)"
  1881. ]
  1882. },
  1883. {
  1884. "cell_type": "code",
  1885. "execution_count": 22,
  1886. "id": "37871d6b",
  1887. "metadata": {},
  1888. "outputs": [
  1889. {
  1890. "data": {
  1891. "application/vnd.jupyter.widget-view+json": {
  1892. "model_id": "",
  1893. "version_major": 2,
  1894. "version_minor": 0
  1895. },
  1896. "text/plain": [
  1897. "Output()"
  1898. ]
  1899. },
  1900. "metadata": {},
  1901. "output_type": "display_data"
  1902. },
  1903. {
  1904. "data": {
  1905. "text/html": [
  1906. "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
  1907. ],
  1908. "text/plain": []
  1909. },
  1910. "metadata": {},
  1911. "output_type": "display_data"
  1912. },
  1913. {
  1914. "data": {
  1915. "text/plain": [
  1916. "{'f#F1': 0.75283, 'pre#F1': 0.727438, 'rec#F1': 0.780059}"
  1917. ]
  1918. },
  1919. "execution_count": 22,
  1920. "metadata": {},
  1921. "output_type": "execute_result"
  1922. }
  1923. ],
  1924. "source": [
  1925. "trainer.evaluator.run()"
  1926. ]
  1927. },
  1928. {
  1929. "cell_type": "code",
  1930. "execution_count": null,
  1931. "id": "96bae094",
  1932. "metadata": {},
  1933. "outputs": [],
  1934. "source": []
  1935. }
  1936. ],
  1937. "metadata": {
  1938. "kernelspec": {
  1939. "display_name": "Python 3 (ipykernel)",
  1940. "language": "python",
  1941. "name": "python3"
  1942. },
  1943. "language_info": {
  1944. "codemirror_mode": {
  1945. "name": "ipython",
  1946. "version": 3
  1947. },
  1948. "file_extension": ".py",
  1949. "mimetype": "text/x-python",
  1950. "name": "python",
  1951. "nbconvert_exporter": "python",
  1952. "pygments_lexer": "ipython3",
  1953. "version": "3.7.13"
  1954. }
  1955. },
  1956. "nbformat": 4,
  1957. "nbformat_minor": 5
  1958. }