{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fdd7ff16",
   "metadata": {},
   "source": [
    "# T4. Predefined models in fastNLP\n",
    "\n",
    "  1   Introduction to modules in fastNLP\n",
    "\n",
    "    1.1   A brief overview of the modules and models packages\n",
    "\n",
    "    1.2   Example 1: LSTM classification with modules\n",
    "\n",
    "  2   Introduction to models in fastNLP\n",
    "\n",
    "    2.1   Example 1: CNN classification with models\n",
    "\n",
    "    2.2   Example 2: BiLSTM tagging with models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3d65d53",
   "metadata": {},
   "source": [
    "## 1. Introduction to the modules package in fastNLP\n",
    "\n",
    "### 1.1 A brief overview of the modules and models packages\n",
    "\n",
    "In `fastNLP 0.8`, **the `modules.torch` package defines a number of basic building blocks implemented on top of `pytorch`**\n",
    "\n",
    "    including the long short-term memory network `LSTM`, the conditional random field `CRF`, `transformer` encoder/decoder modules, and more, as listed below\n",
    "\n",
    "| <div align=\"center\">Name</div> | <div align=\"center\">Brief description</div> | <div align=\"center\">Path</div> |\n",
    "|:--|:--|:--|\n",
    "| `LSTM` | light wrapper around `pytorch`'s `LSTM` | `/modules/torch/encoder/lstm.py` |\n",
    "| `Seq2SeqEncoder` | sequence-to-sequence encoder, base class | `/modules/torch/encoder/seq2seq_encoder.py` |\n",
    "| `LSTMSeq2SeqEncoder` | sequence-to-sequence encoder based on `LSTM` | `/modules/torch/encoder/seq2seq_encoder.py` |\n",
    "| `TransformerSeq2SeqEncoder` | sequence-to-sequence encoder based on `transformer` | `/modules/torch/encoder/seq2seq_encoder.py` |\n",
    "| `StarTransformer` | encoder part of `Star-Transformer` | `/modules/torch/encoder/star_transformer.py` |\n",
    "| `VarRNN` | implements `Variational Dropout RNN` | `/modules/torch/encoder/variational_rnn.py` |\n",
    "| `VarLSTM` | implements `Variational Dropout LSTM` | `/modules/torch/encoder/variational_rnn.py` |\n",
    "| `VarGRU` | implements `Variational Dropout GRU` | `/modules/torch/encoder/variational_rnn.py` |\n",
    "| `ConditionalRandomField` | conditional random field model | `/modules/torch/decoder/crf.py` |\n",
    "| `Seq2SeqDecoder` | sequence-to-sequence decoder, base class | `/modules/torch/decoder/seq2seq_decoder.py` |\n",
    "| `LSTMSeq2SeqDecoder` | sequence-to-sequence decoder based on `LSTM` | `/modules/torch/decoder/seq2seq_decoder.py` |\n",
    "| `TransformerSeq2SeqDecoder` | sequence-to-sequence decoder based on `transformer` | `/modules/torch/decoder/seq2seq_decoder.py` |\n",
    "| `SequenceGenerator` | sequence generation, wraps a `Seq2SeqDecoder` | `/modules/torch/generator/seq2seq_generator.py` |\n",
    "| `TimestepDropout` | applies `dropout` at each timestep | `/modules/torch/dropout.py` |"
   ]
  },
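  {
   "cell_type": "markdown",
   "id": "b7f1a2c3",
   "metadata": {},
   "source": [
    "&emsp; As a quick, illustrative sketch (not part of the original tutorial), the cell below instantiates the `LSTM` module from the table and runs a random batch through it; the call pattern follows the example in section 1.2, and the shapes in the comments are assumptions based on the module's `batch_first` default."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7f1a2c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch, assuming fastNLP 0.8: run a random batch through modules.torch.LSTM,\n",
    "# mirroring the call pattern used in section 1.2 below.\n",
    "import torch\n",
    "from fastNLP.modules.torch import LSTM\n",
    "\n",
    "lstm = LSTM(100, 64, num_layers=1, bidirectional=True)  # input size 100, hidden size 64\n",
    "x = torch.randn(4, 10, 100)                             # (batch, seq_len, input_size)\n",
    "output, (hidden, cell) = lstm(x)\n",
    "print(output.shape)                                     # expected: (4, 10, 128) for a bidirectional LSTM"
   ]
  },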
  {
   "cell_type": "markdown",
   "id": "89ffcf07",
   "metadata": {},
   "source": [
    "&emsp; **The `models.torch` package defines a number of predefined models built on `pytorch` and the `modules` package**\n",
    "\n",
    "&emsp; &emsp; for example, a classification model based on `CNN`, a tagging model based on `BiLSTM+CRF`, a parsing model based on [biaffine attention](https://arxiv.org/pdf/1611.01734.pdf)\n",
    "\n",
    "&emsp; &emsp; and sequence-to-sequence/generation models built from the `LSTM`/`transformer` encoder/decoder modules in `modules.torch`, as listed below\n",
    "\n",
    "| <div align=\"center\">Name</div> | <div align=\"center\">Brief description</div> | <div align=\"center\">Path</div> |\n",
    "|:--|:--|:--|\n",
    "| `BiaffineParser` | dependency parsing model based on biaffine attention | `/models/torch/biaffine_parser.py` |\n",
    "| `CNNText` | text classification model based on `CNN` | `/models/torch/cnn_text_classification.py` |\n",
    "| `Seq2SeqModel` | sequence-to-sequence base class, `encoder+decoder` | `/models/torch/seq2seq_model.py` |\n",
    "| `LSTMSeq2SeqModel` | sequence-to-sequence model based on `LSTM` | `/models/torch/seq2seq_model.py` |\n",
    "| `TransformerSeq2SeqModel` | sequence-to-sequence model based on `transformer` | `/models/torch/seq2seq_model.py` |\n",
    "| `SequenceGeneratorModel` | wraps a `Seq2SeqModel` together with a `SequenceGenerator` | `/models/torch/seq2seq_generator.py` |\n",
    "| `SeqLabeling` | tagging model, base class, `LSTM+FC+CRF` | `/models/torch/sequence_labeling.py` |\n",
    "| `BiLSTMCRF` | tagging model, `BiLSTM+FC+CRF` | `/models/torch/sequence_labeling.py` |\n",
    "| `AdvSeqLabel` | tagging model, `LN+BiLSTM*2+LN+FC+CRF` | `/models/torch/sequence_labeling.py` |"
   ]
  },
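  {
   "cell_type": "markdown",
   "id": "c8d2e3f1",
   "metadata": {},
   "source": [
    "&emsp; All of these predefined models expose `train_step`/`evaluate_step` methods that the `Trainer` calls by field name; the illustrative sketch below (not from the original tutorial; vocabulary and batch sizes are made up) exercises that interface on `CNNText` directly, using the `words`/`target` parameter names described in section 2.1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8d2e3f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch, assuming fastNLP 0.8: call a predefined model's train_step\n",
    "# by hand. Sizes here are made up; sections 2.1/2.2 show the real training setup.\n",
    "import torch\n",
    "from fastNLP.models.torch import CNNText\n",
    "\n",
    "model = CNNText(embed=(100, 50), num_classes=2)              # vocab of 100 words, 50-dim embeddings\n",
    "words = torch.randint(0, 100, (4, 12))                       # a batch of 4 index sequences of length 12\n",
    "target = torch.randint(0, 2, (4,))                           # one binary label per sequence\n",
    "print(model.train_step(words=words, target=target)['loss'])  # train_step returns {'loss': ...}\n",
    "# evaluate_step analogously returns the predictions consumed by the metrics"
   ]
  },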
  {
   "cell_type": "markdown",
   "id": "61318354",
   "metadata": {},
   "source": [
    "The `fastNLP` components above not only **give entry-level users simple, easy-to-use tools** for solving a variety of `NLP` tasks or reproducing papers\n",
    "\n",
    "&emsp; they **also give researchers convenient, controllable interfaces**: boilerplate is encapsulated, while details can still be changed through parameters\n",
    "\n",
    "&emsp; in the rest of this `tutorial`, we demonstrate these models on `SST-2` classification and `CoNLL-2003` tagging\n",
    "\n",
    "Note 1: **`SST-2`** is a **single-sentence sentiment classification** dataset of movie reviews with polarity labels, where 1 is positive and 0 is negative\n",
    "\n",
    "&emsp; it has three splits: 67349 training, 872 validation, and 1821 test examples; see the [download page](https://gluebenchmark.com/tasks) for more\n",
    "\n",
    "Note 2: **`CoNLL-2003`** is a **linguistic annotation** dataset containing sentences with part-of-speech tags `pos_tags` (noun, verb, adjective, etc.)\n",
    "\n",
    "&emsp; syntactic chunk tags `chunk_tags` (subject, predicate, object, etc.), and named entity tags `ner_tags` (person, organization, location, miscellaneous)\n",
    "\n",
    "&emsp; it has three splits: 14041 training, 3250 validation, and 3453 test sentences; see the [original paper](https://aclanthology.org/W03-0419.pdf) for more"
   ]
  },
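  {
   "cell_type": "markdown",
   "id": "d4e5f6a1",
   "metadata": {},
   "source": [
    "&emsp; As an illustrative aside (not in the original tutorial), one raw example from each dataset can be inspected as below; this assumes the `datasets` package and network access, and the field names match the notes above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4e5f6a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative peek at the raw fields of the two datasets described above;\n",
    "# assumes the `datasets` package is installed and can download them.\n",
    "from datasets import load_dataset\n",
    "\n",
    "print(load_dataset('glue', 'sst2')['train'][0])   # fields: 'sentence', 'label', 'idx'\n",
    "print(load_dataset('conll2003')['train'][0])      # fields: 'id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'"
   ]
  },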
  {
   "cell_type": "markdown",
   "id": "2a36bbe4",
   "metadata": {},
   "source": [
    "### 1.2 Example 1: LSTM classification with modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40e66b21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import sys\n",
    "# sys.path.append('..')\n",
    "\n",
    "# from fastNLP.io import SST2Pipe  # commented out: SST2Pipe runs for a very long time here and then raises an error\n",
    "\n",
    "# databundle = SST2Pipe(tokenizer='raw').process_from_file()\n",
    "\n",
    "# dataset = databundle.get_dataset('train')[:6000]\n",
    "\n",
    "# dataset.apply_more(lambda ins: {'words': ins['sentence'].lower().split(), 'target': ins['label']},\n",
    "#                    progress_bar=\"tqdm\")\n",
    "# dataset.delete_field('sentence')\n",
    "# dataset.delete_field('label')\n",
    "# dataset.delete_field('idx')\n",
    "\n",
    "# from fastNLP import Vocabulary\n",
    "\n",
    "# vocab = Vocabulary()\n",
    "# vocab.from_dataset(dataset, field_name='words')\n",
    "# vocab.index_dataset(dataset, field_name='words')\n",
    "\n",
    "# train_dataset, evaluate_dataset = dataset.split(ratio=0.85)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50960476",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from fastNLP import prepare_torch_dataloader\n",
    "\n",
    "# train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
    "# evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b25b25c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import torch\n",
    "# import torch.nn as nn\n",
    "\n",
    "# from fastNLP.modules.torch import LSTM, MLP  # MLP does not exist in modules.torch\n",
    "# from fastNLP import Embedding, CrossEntropyLoss\n",
    "\n",
    "\n",
    "# class ClsByModules(nn.Module):\n",
    "#     def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n",
    "#         nn.Module.__init__(self)\n",
    "\n",
    "#         self.embedding = Embedding((vocab_size, embedding_dim))\n",
    "#         self.lstm = LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True)\n",
    "#         self.mlp = MLP([hidden_dim * 2, output_dim], dropout=dropout)\n",
    "\n",
    "#         self.loss_fn = CrossEntropyLoss()\n",
    "\n",
    "#     def forward(self, words):\n",
    "#         output = self.embedding(words)\n",
    "#         output, (hidden, cell) = self.lstm(output)\n",
    "#         output = self.mlp(torch.cat((hidden[-1], hidden[-2]), dim=1))\n",
    "#         return output\n",
    "\n",
    "#     def train_step(self, words, target):\n",
    "#         pred = self(words)\n",
    "#         return {\"loss\": self.loss_fn(pred, target)}\n",
    "\n",
    "#     def evaluate_step(self, words, target):\n",
    "#         pred = self(words)\n",
    "#         pred = torch.max(pred, dim=-1)[1]\n",
    "#         return {\"pred\": pred, \"target\": target}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dbbf50d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model = ClsByModules(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n",
    "\n",
    "# from torch.optim import AdamW\n",
    "\n",
    "# optimizers = AdamW(params=model.parameters(), lr=5e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a93432f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from fastNLP import Trainer, Accuracy\n",
    "\n",
    "# trainer = Trainer(\n",
    "#     model=model,\n",
    "#     driver='torch',\n",
    "#     device=0,  # 'cuda'\n",
    "#     n_epochs=10,\n",
    "#     optimizers=optimizers,\n",
    "#     train_dataloader=train_dataloader,\n",
    "#     evaluate_dataloaders=evaluate_dataloader,\n",
    "#     metrics={'acc': Accuracy()}\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31102e0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# trainer.run(num_eval_batch_per_dl=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bc4bfb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# trainer.evaluator.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9443213",
   "metadata": {},
   "source": [
    "## 2. Introduction to the models package in fastNLP\n",
    "\n",
    "### 2.1 Example 1: CNN classification with models\n",
    "\n",
    "&emsp; This example uses `CNNText`, a predefined model from `models` in `fastNLP 0.8`, for binary text classification on `SST-2`\n",
    "\n",
    "On the model side, as noted above, we use **`CNNText`, the predefined text classification model based on a convolutional neural network (`CNN`)**, whose structure is shown below\n",
    "\n",
    "&emsp; first come a built-in `100`-dimensional embedding layer and a `dropout` layer, followed by three one-dimensional convolutions that map the `100`-dimensional embeddings\n",
    "\n",
    "&emsp; &emsp; **through kernels with receptive fields of `1`, `3`, and `5` into `30`-, `40`-, and `50`-dimensional convolutional features**, which are then concatenated\n",
    "\n",
    "&emsp; finally, another `dropout` layer and a linear layer map the result to two output values, the `logits` over the two classes\n",
    "\n",
    "```\n",
    "CNNText(\n",
    "  (embed): Embedding(\n",
    "    (embed): Embedding(5194, 100)\n",
    "    (dropout): Dropout(p=0.0, inplace=False)\n",
    "  )\n",
    "  (conv_pool): ConvMaxpool(\n",
    "    (convs): ModuleList(\n",
    "      (0): Conv1d(100, 30, kernel_size=(1,), stride=(1,), bias=False)\n",
    "      (1): Conv1d(100, 40, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
    "      (2): Conv1d(100, 50, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
    "    )\n",
    "  )\n",
    "  (dropout): Dropout(p=0.1, inplace=False)\n",
    "  (fc): Linear(in_features=120, out_features=2, bias=True)\n",
    ")\n",
    "```\n",
    "\n",
    "On the data side, we **use the `load_dataset` function from the `datasets` package** to load `SST-2` automatically, as follows\n",
    "\n",
    "&emsp; after the first download it is cached under `~/.cache/huggingface/modules/datasets_modules/datasets/glue/`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1aa5cf6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "sst2data = load_dataset('glue', 'sst2')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c476abe7",
   "metadata": {},
   "source": [
    "Next, following `tutorial-1` and `tutorial-2`, we convert the dataset into fastNLP's `DataSet` format\n",
    "\n",
    "&emsp; **preprocessing the data with the `apply_more` function and the `from_dataset`/`index_dataset` methods of `Vocabulary`**\n",
    "\n",
    "&emsp; &emsp; using `delete_field` to drop unused fields, and `split` to divide the data into training and validation sets\n",
    "\n",
    "&emsp; **keeping only the `'words'` field, the sequence of token indices for the input text, and the `'target'` field, the expected prediction**\n",
    "\n",
    "&emsp; &emsp; the two fields **match the signatures/input parameters of the `train_step` and `evaluate_step` functions of `CNNText`**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "357ea748",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "from fastNLP import DataSet\n",
    "\n",
    "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n",
    "\n",
    "dataset.apply_more(lambda ins: {'words': ins['sentence'].lower().split(), 'target': ins['label']},\n",
    "                   progress_bar=\"tqdm\")\n",
    "dataset.delete_field('sentence')\n",
    "dataset.delete_field('label')\n",
    "dataset.delete_field('idx')\n",
    "\n",
    "from fastNLP import Vocabulary\n",
    "\n",
    "vocab = Vocabulary()\n",
    "vocab.from_dataset(dataset, field_name='words')\n",
    "vocab.index_dataset(dataset, field_name='words')\n",
    "\n",
    "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96380c67",
   "metadata": {},
   "source": [
    "Then, following `tutorial-3`, we **build `dataloader`s from the datasets with `prepare_torch_dataloader`**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9dd1273",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP import prepare_torch_dataloader\n",
    "\n",
    "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
    "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96941b63",
   "metadata": {},
   "source": [
    "Next, we **import `CNNText` from `fastNLP.models.torch`** and initialize a `CNNText` instance and an `optimizer` instance\n",
    "\n",
    "&emsp; note: when initializing `CNNText`, **the tuple parameter `embed` and the number of classes `num_classes` must be passed in**, where\n",
    "\n",
    "&emsp; &emsp; **`embed` gives the shape of the embedding lookup matrix**, so its second element is the embedding dimension, here the default `100`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6e76e2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP.models.torch import CNNText\n",
    "\n",
    "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n",
    "\n",
    "from torch.optim import AdamW\n",
    "\n",
    "optimizers = AdamW(params=model.parameters(), lr=5e-4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0cc5ca10",
   "metadata": {},
   "source": [
    "Finally, we use the `trainer` module to tie together the `model`, `optimizer`, `dataloader`, and `metric` for training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50a13ee5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP import Trainer, Accuracy\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    driver='torch',\n",
    "    device=0,  # 'cuda'\n",
    "    n_epochs=10,\n",
    "    optimizers=optimizers,\n",
    "    train_dataloader=train_dataloader,\n",
    "    evaluate_dataloaders=evaluate_dataloader,\n",
    "    metrics={'acc': Accuracy()}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28903a7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f47a6a35",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.evaluator.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c811257",
   "metadata": {},
   "source": [
    "&emsp; Note: here we delete the relevant variables and use the `gc` module to free memory, making room for training the next model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a2e2ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "\n",
    "del model\n",
    "del trainer\n",
    "del dataset\n",
    "del sst2data\n",
    "\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6aec2a19",
   "metadata": {},
   "source": [
    "### 2.2 Example 2: BiLSTM tagging with models\n",
    "\n",
    "&emsp; Comparing the two Example 1s shows that, thanks to the way `models` encapsulates model structure, using `models` is clearly more convenient\n",
    "\n",
    "&emsp; &emsp; and the more complex the model, the easier the coding becomes; this example uses the `BiLSTMCRF` model from `models`\n",
    "\n",
    "&emsp; avoiding the difficulty of hand-writing the `CRF` and `Viterbi` algorithms, to tackle named entity recognition (`NER`) on `CoNLL-2003` with ease\n",
    "\n",
    "On the model side, as noted above, we use **`BiLSTMCRF`, the tagging model based on a bidirectional `LSTM` plus a conditional random field (`CRF`)**, whose structure is shown below\n",
    "\n",
    "&emsp; the hidden size defaults to `100`, so the bidirectional `LSTM` outputs `200` dimensions; the `dropout` probability and the number of `LSTM` layers are adjustable\n",
    "\n",
    "```\n",
    "BiLSTMCRF(\n",
    "  (embed): Embedding(7590, 100)\n",
    "  (lstm): LSTM(\n",
    "    (lstm): LSTM(100, 100, batch_first=True, bidirectional=True)\n",
    "  )\n",
    "  (dropout): Dropout(p=0.1, inplace=False)\n",
    "  (fc): Linear(in_features=200, out_features=9, bias=True)\n",
    "  (crf): ConditionalRandomField()\n",
    ")\n",
    "```\n",
    "\n",
    "On the data side, we again **use the `load_dataset` function from the `datasets` package** to load the `CoNLL-2003` dataset, as follows\n",
    "\n",
    "&emsp; after the first download it is cached under `~/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03e66686",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "ner2data = load_dataset('conll2003', 'conll2003')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc505631",
   "metadata": {},
   "source": [
    "Next, following `tutorial-1` and `tutorial-2`, we convert the dataset into fastNLP's `DataSet` format\n",
    "\n",
    "&emsp; adjusting the dataset's format and indexing the text; here **three fields are needed: `'words'`, `'seq_len'`, and `'target'`**\n",
    "\n",
    "In addition, **we need a mapping from `NER` labels to label indices** (**the vocabulary `label_vocab`**); the labels in the dataset are already index-mapped\n",
    "\n",
    "&emsp; so we manually define **`9` labels matching the `9` existing target indices**; per the dataset documentation, `'O'` marks tokens outside any entity\n",
    "\n",
    "&emsp; **the suffixes `'-PER'`, `'-ORG'`, `'-LOC'`, and `'-MISC'` correspond to person, organization, location, and miscellaneous entities**\n",
    "\n",
    "&emsp; **the prefix `'B-'` marks the first token of an entity and `'I-'` marks the tokens inside it**; for example, `'B-PER'` marks the start of a person entity (see the decoding sketch below)"
   ]
  },
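  {
   "cell_type": "markdown",
   "id": "e1f2a3b4",
   "metadata": {},
   "source": [
    "&emsp; To make the `'B-'`/`'I-'` convention concrete, here is a small illustrative decoder (not part of the original tutorial) that turns a tag sequence into `(label, start, end)` entity spans:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1f2a3b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch: decode a BIO tag sequence into (label, start, end) spans.\n",
    "# Not used by the rest of the notebook; it only demonstrates the tagging scheme.\n",
    "tags = ['B-PER', 'I-PER', 'O', 'B-LOC']\n",
    "\n",
    "spans, start = [], None\n",
    "for i, tag in enumerate(tags + ['O']):               # sentinel 'O' closes a trailing entity\n",
    "    if start is not None and not tag.startswith('I-'):\n",
    "        spans.append((tags[start][2:], start, i))    # end index is exclusive\n",
    "        start = None\n",
    "    if tag.startswith('B-'):\n",
    "        start = i\n",
    "\n",
    "print(spans)  # [('PER', 0, 2), ('LOC', 3, 4)]"
   ]
  },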
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f88cad4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "from fastNLP import DataSet\n",
    "\n",
    "dataset = DataSet.from_pandas(ner2data['train'].to_pandas())[:4000]\n",
    "\n",
    "dataset.apply_more(lambda ins: {'words': ins['tokens'], 'seq_len': len(ins['tokens']), 'target': ins['ner_tags']},\n",
    "                   progress_bar=\"tqdm\")\n",
    "dataset.delete_field('tokens')\n",
    "dataset.delete_field('ner_tags')\n",
    "dataset.delete_field('pos_tags')\n",
    "dataset.delete_field('chunk_tags')\n",
    "dataset.delete_field('id')\n",
    "\n",
    "from fastNLP import Vocabulary\n",
    "\n",
    "token_vocab = Vocabulary()\n",
    "token_vocab.from_dataset(dataset, field_name='words')\n",
    "token_vocab.index_dataset(dataset, field_name='words')\n",
    "label_vocab = Vocabulary(padding=None, unknown=None)\n",
    "label_vocab.add_word_lst(['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'])\n",
    "\n",
    "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9889427",
   "metadata": {},
   "source": [
    "Then, again following `tutorial-3`, we build `dataloader`s from the datasets with `prepare_torch_dataloader`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7802a072",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP import prepare_torch_dataloader\n",
    "\n",
    "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
    "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2bc7831b",
   "metadata": {},
   "source": [
    "Next, we **import `BiLSTMCRF` from `fastNLP.models.torch`** and initialize a `BiLSTMCRF` instance and an optimizer\n",
    "\n",
    "&emsp; note: when initializing `BiLSTMCRF`, as with `CNNText`, **the parameters `embed` and `num_classes` must be passed in**\n",
    "\n",
    "&emsp; &emsp; the hidden size `hidden_size` defaults to `100` and is raised to `150` here; the dropout probability defaults to `0.1` and is raised to `0.2`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e12c09f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP.models.torch import BiLSTMCRF\n",
    "\n",
    "model = BiLSTMCRF(embed=(len(token_vocab), 150), num_classes=len(label_vocab),\n",
    "                  num_layers=1, hidden_size=150, dropout=0.2)\n",
    "\n",
    "from torch.optim import AdamW\n",
    "\n",
    "optimizers = AdamW(params=model.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf30608f",
   "metadata": {},
   "source": [
    "Finally, we use the `trainer` module to tie together the `model`, `optimizer`, `dataloader`, and `metric` for training\n",
    "\n",
    "&emsp; **`SpanFPreRecMetric` serves as the evaluation metric for `NER`**; see the upcoming `tutorial-5` for details\n",
    "\n",
    "&emsp; note that **its initialization requires `tag_vocab`, the `Vocabulary`-style mapping between labels and indices**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbd6c205",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP import Trainer, SpanFPreRecMetric\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    driver='torch',\n",
    "    device=0,  # 'cuda'\n",
    "    n_epochs=10,\n",
    "    optimizers=optimizers,\n",
    "    train_dataloader=train_dataloader,\n",
    "    evaluate_dataloaders=evaluate_dataloader,\n",
    "    metrics={'F1': SpanFPreRecMetric(tag_vocab=label_vocab)}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f8eff34",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.run(num_eval_batch_per_dl=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37871d6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.evaluator.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96bae094",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}