{ "cells": [ { "cell_type": "markdown", "id": "fdd7ff16", "metadata": {}, "source": [ "# T5. Predefined models in fastNLP\n", "\n", " 1 Introduction to modules in fastNLP\n", " \n", " 1.1 Overview of the modules and models packages\n", "\n", " 1.2 Example 1: LSTM classification with modules\n", "\n", " 2 Introduction to models in fastNLP\n", " \n", " 2.1 Example 1: CNN classification with models\n", "\n", " 2.2 Example 2: BiLSTM tagging with models" ] }, { "cell_type": "markdown", "id": "d3d65d53", "metadata": {}, "source": [ "## 1. Introduction to the modules package in fastNLP\n", "\n", "### 1.1 Overview of the modules and models packages\n", "\n", "In `fastNLP 0.8`, **the `modules.torch` path defines a set of basic building blocks implemented with `pytorch`**,\n", "\n", " including the long short-term memory network `LSTM`, the conditional random field `CRF`, and the `transformer` encoder and decoder modules; see the table below\n", "\n", "|
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Processing: 0%| | 0/6000 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import sys\n", "sys.path.append('..')\n", "\n", "from fastNLP import DataSet\n", "\n", "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n", "\n", "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n", " progress_bar=\"tqdm\")\n", "dataset.delete_field('sentence')\n", "dataset.delete_field('label')\n", "dataset.delete_field('idx')\n", "\n", "from fastNLP import Vocabulary\n", "\n", "vocab = Vocabulary()\n", "vocab.from_dataset(dataset, field_name='words')\n", "vocab.index_dataset(dataset, field_name='words')\n", "\n", "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)" ] }, { "cell_type": "markdown", "id": "96380c67", "metadata": {}, "source": [ "Next, using what was covered in `tutorial-3`, **process the dataset with `prepare_torch_dataloader` to obtain the `dataloader`**" ] }, { "cell_type": "code", "execution_count": 10, "id": "b9dd1273", "metadata": {}, "outputs": [], "source": [ "from fastNLP import prepare_torch_dataloader\n", "\n", "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" ] }, { "cell_type": "markdown", "id": "96941b63", "metadata": {}, "source": [ "Then, **import `CNNText` from the `fastNLP.models.torch` path**, and initialize a `CNNText` instance and an `optimizer` instance\n", "\n", " Note: when initializing `CNNText`, **the 2-tuple parameter `embed` and the number of classes `num_classes` must be passed in**, where\n", "\n", " **`embed` gives the size of the embedding lookup matrix**, so its second element corresponds to the default hidden dimension of `100`" ] }, { "cell_type": "code", "execution_count": 11, "id": "f6e76e2e", "metadata": {}, "outputs": [], "source": [ "from fastNLP.models.torch import CNNText\n", "\n", "model = CNNText(embed=(len(vocab), 100), num_classes=2, 
dropout=0.1)\n", "\n", "from torch.optim import AdamW\n", "\n", "optimizers = AdamW(params=model.parameters(), lr=5e-4)" ] }, { "cell_type": "markdown", "id": "0cc5ca10", "metadata": {}, "source": [ "Finally, use the `trainer` module to bring together `model`, `optimizer`, `dataloader`, and `metric` for training" ] }, { "cell_type": "code", "execution_count": 12, "id": "50a13ee5", "metadata": {}, "outputs": [], "source": [ "from fastNLP import Trainer, Accuracy\n", "\n", "trainer = Trainer(\n", " model=model,\n", " driver='torch',\n", " device=0, # 'cuda'\n", " n_epochs=10,\n", " optimizers=optimizers,\n", " train_dataloader=train_dataloader,\n", " evaluate_dataloaders=evaluate_dataloader,\n", " metrics={'acc': Accuracy()}\n", ")" ] }, { "cell_type": "code", "execution_count": 13, "id": "28903a7d", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
[17:45:59] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", "\n" ], "text/plain": [ "\u001b[2;36m[17:45:59]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=147745;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=708408;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.575,\n", " \"total#acc\": 160.0,\n", " \"correct#acc\": 92.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.575\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m92.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.75625,\n", " \"total#acc\": 160.0,\n", " \"correct#acc\": 121.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.78125,\n", " \"total#acc\": 160.0,\n", " \"correct#acc\": 125.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.78125\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m125.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.8,\n", " \"total#acc\": 160.0,\n", " \"correct#acc\": 128.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.79375,\n", " \"total#acc\": 160.0,\n", " \"correct#acc\": 127.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.80625,\n", " \"total#acc\": 160.0,\n", " \"correct#acc\": 129.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m129.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.81875,\n", " \"total#acc\": 160.0,\n", " \"correct#acc\": 131.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.825,\n", " \"total#acc\": 160.0,\n", " \"correct#acc\": 132.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.825\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m132.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.81875,\n", " \"total#acc\": 160.0,\n", " \"correct#acc\": 131.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.81875,\n", " \"total#acc\": 160.0,\n", " \"correct#acc\": 131.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] }, { "cell_type": "code", "execution_count": 14, "id": "f47a6a35", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "{'acc#acc': 0.79, 'total#acc': 900.0, 'correct#acc': 711.0}" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.evaluator.run()" ] }, { "cell_type": "markdown", "id": "7c811257", "metadata": {}, "source": [ " Note: here the related variables are deleted and the `gc` module is used to free memory, leaving room for training the next model" ] }, { "cell_type": "code", "execution_count": 15, "id": "c1a2e2ca", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "342" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import gc\n", "\n", "del model\n", "del trainer\n", "del dataset\n", "del sst2data\n", "\n", "gc.collect()" ] }, { "cell_type": "markdown", "id": "6aec2a19", "metadata": {}, "source": [ "### 2.2 Example 2: BiLSTM tagging with models\n", "\n", " Comparing the two Example 1 implementations shows that, because `models` wraps the model structure, using `models` is clearly more convenient\n", "\n", " For more complex models the coding gets easier still; this example uses the `BiLSTMCRF` model from `models`,\n", "\n", " which avoids having to hand-write the `CRF` and `Viterbi` algorithm code, to carry out the named entity recognition (`NER`) task on `CoNLL-2003`\n", "\n", "On the model side, as noted above, this example uses **`BiLSTMCRF`, a tagging model built on a bidirectional `LSTM` plus a conditional random field (`CRF`)**, with the structure shown below\n", "\n", " The hidden size defaults to `100`, so the bidirectional `LSTM` outputs `200` dimensions; the dropout probability of the `dropout` layer and the number of `LSTM` layers are configurable\n", "\n", "```\n", "BiLSTMCRF(\n", " (embed): Embedding(7590, 100)\n", " (lstm): LSTM(\n", " (lstm): LSTM(100, 100, batch_first=True, bidirectional=True)\n", " )\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (fc): Linear(in_features=200, out_features=9, bias=True)\n", " (crf): ConditionalRandomField()\n", ")\n", "```\n", "\n", 
"数据使用方面,此处仍然**使用`datasets`模块中的`load_dataset`函数**,以如下形式,加载`CoNLL-2003`数据集\n", "\n", " 首次下载后会保存至`~.cache/huggingface/datasets/conll2003/conll2003/1.0.0/`目录下" ] }, { "cell_type": "code", "execution_count": 16, "id": "03e66686", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Reusing dataset conll2003 (/remote-home/xrliu/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3ec9e0ce9a054339a2453420c2c9f28b", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from datasets import load_dataset\n", "\n", "ner2data = load_dataset('conll2003', 'conll2003')" ] }, { "cell_type": "markdown", "id": "fc505631", "metadata": {}, "source": [ "紧接着,使用`tutorial-1`和`tutorial-2`中的知识,将数据集转化为`fastNLP`中的`DataSet`格式\n", "\n", " 完成数据集格式调整、文本序列化等操作;此处**需要`'words'`、`'seq_len'`、`'target'`三个字段**\n", "\n", "此外,**需要定义`NER`标签到标签序号的映射**(**词汇表`label_vocab`**),数据集中标签已经完成了序号映射\n", "\n", " 所以需要人工定义**`9`个标签对应之前的`9`个分类目标**;数据集说明中规定,`'O'`表示其他标签\n", "\n", " **后缀`'-PER'`、`'-ORG'`、`'-LOC'`、`'-MISC'`对应人名、组织名、地名、时间等其他命名**\n", "\n", " **前缀`'B-'`表示起始标签、`'I-'`表示终止标签**;例如,`'B-PER'`表示人名实体的起始标签" ] }, { "cell_type": "code", "execution_count": 17, "id": "1f88cad4", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Processing: 0%| | 0/4000 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import sys\n", "sys.path.append('..')\n", "\n", "from fastNLP import DataSet\n", "\n", "dataset = DataSet.from_pandas(ner2data['train'].to_pandas())[:4000]\n", "\n", "dataset.apply_more(lambda ins:{'words': ins['tokens'], 'seq_len': len(ins['tokens']), 'target': ins['ner_tags']}, \n", " 
progress_bar=\"tqdm\")\n", "dataset.delete_field('tokens')\n", "dataset.delete_field('ner_tags')\n", "dataset.delete_field('pos_tags')\n", "dataset.delete_field('chunk_tags')\n", "dataset.delete_field('id')\n", "\n", "from fastNLP import Vocabulary\n", "\n", "token_vocab = Vocabulary()\n", "token_vocab.from_dataset(dataset, field_name='words')\n", "token_vocab.index_dataset(dataset, field_name='words')\n", "label_vocab = Vocabulary(padding=None, unknown=None)\n", "label_vocab.add_word_lst(['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'])\n", "\n", "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)" ] }, { "cell_type": "markdown", "id": "d9889427", "metadata": {}, "source": [ "Then, again using what was covered in `tutorial-3`, process the dataset with `prepare_torch_dataloader` to obtain the `dataloader`" ] }, { "cell_type": "code", "execution_count": 18, "id": "7802a072", "metadata": {}, "outputs": [], "source": [ "from fastNLP import prepare_torch_dataloader\n", "\n", "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" ] }, { "cell_type": "markdown", "id": "2bc7831b", "metadata": {}, "source": [ "Next, **import `BiLSTMCRF` from the `fastNLP.models.torch` path**, and initialize a `BiLSTMCRF` instance and the optimizer\n", "\n", " Note: when initializing `BiLSTMCRF`, as with `CNNText`, **the parameters `embed` and `num_classes` must be passed in**\n", "\n", " The hidden size `hidden_size` defaults to `100` and is set to `150` here; the dropout probability defaults to `0.1` and is set to `0.2`" ] }, { "cell_type": "code", "execution_count": 19, "id": "4e12c09f", "metadata": {}, "outputs": [], "source": [ "from fastNLP.models.torch import BiLSTMCRF\n", "\n", "model = BiLSTMCRF(embed=(len(token_vocab), 150), num_classes=len(label_vocab), \n", " num_layers=1, hidden_size=150, dropout=0.2)\n", "\n", "from torch.optim import AdamW\n", "\n", "optimizers = AdamW(params=model.parameters(), lr=1e-3)" ] }, { "cell_type": "markdown", "id": "bf30608f", "metadata": {}, "source": [ "Finally, use the `trainer` module to bring together `model`, `optimizer`, `dataloader`, and `metric` for training\n", "\n", " 
Following `tutorial-4`, **use `SpanFPreRecMetric` as the evaluation metric for `NER`**\n", "\n", " In addition, **initialization requires `tag_vocab`, the label-to-index mapping given as a `vocabulary`**" ] }, { "cell_type": "code", "execution_count": 20, "id": "cbd6c205", "metadata": {}, "outputs": [], "source": [ "from fastNLP import Trainer, SpanFPreRecMetric\n", "\n", "trainer = Trainer(\n", " model=model,\n", " driver='torch',\n", " device=0, # 'cuda'\n", " n_epochs=10,\n", " optimizers=optimizers,\n", " train_dataloader=train_dataloader,\n", " evaluate_dataloaders=evaluate_dataloader,\n", " metrics={'F1': SpanFPreRecMetric(tag_vocab=label_vocab)}\n", ")" ] }, { "cell_type": "code", "execution_count": 21, "id": "0f8eff34", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
[17:49:16] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", "\n" ], "text/plain": [ "\u001b[2;36m[17:49:16]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=766109;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=787419;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"f#F1\": 0.220374,\n", " \"pre#F1\": 0.25,\n", " \"rec#F1\": 0.197026\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.220374\u001b[0m,\n", " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.25\u001b[0m,\n", " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.197026\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"f#F1\": 0.442857,\n", " \"pre#F1\": 0.426117,\n", " \"rec#F1\": 0.460967\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.442857\u001b[0m,\n", " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.426117\u001b[0m,\n", " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.460967\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"f#F1\": 0.572954,\n", " \"pre#F1\": 0.549488,\n", " \"rec#F1\": 0.598513\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.572954\u001b[0m,\n", " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.549488\u001b[0m,\n", " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.598513\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"f#F1\": 0.665399,\n", " \"pre#F1\": 0.680934,\n", " \"rec#F1\": 0.650558\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.665399\u001b[0m,\n", " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.680934\u001b[0m,\n", " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.650558\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"f#F1\": 0.734694,\n", " \"pre#F1\": 0.733333,\n", " \"rec#F1\": 0.736059\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.734694\u001b[0m,\n", " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.733333\u001b[0m,\n", " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.736059\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"f#F1\": 0.742647,\n", " \"pre#F1\": 0.734545,\n", " \"rec#F1\": 0.750929\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.742647\u001b[0m,\n", " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.734545\u001b[0m,\n", " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.750929\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"f#F1\": 0.773585,\n", " \"pre#F1\": 0.785441,\n", " \"rec#F1\": 0.762082\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.773585\u001b[0m,\n", " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.785441\u001b[0m,\n", " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.762082\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"f#F1\": 0.770115,\n", " \"pre#F1\": 0.794466,\n", " \"rec#F1\": 0.747212\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.770115\u001b[0m,\n", " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.794466\u001b[0m,\n", " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.747212\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"f#F1\": 0.7603,\n", " \"pre#F1\": 0.766038,\n", " \"rec#F1\": 0.754647\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.7603\u001b[0m,\n", " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.766038\u001b[0m,\n", " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.754647\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"f#F1\": 0.743682,\n", " \"pre#F1\": 0.722807,\n", " \"rec#F1\": 0.765799\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.743682\u001b[0m,\n", " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.722807\u001b[0m,\n", " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.765799\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] }, { "cell_type": "code", "execution_count": 22, "id": "37871d6b", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "{'f#F1': 0.75283, 'pre#F1': 0.727438, 'rec#F1': 0.780059}" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.evaluator.run()" ] }, { "cell_type": "code", "execution_count": null, "id": "96bae094", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 5 }