|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869 |
- {
- "cells": [
- {
- "cell_type": "markdown",
- "id": "6011adf8",
- "metadata": {},
- "source": [
- "# 10 分钟快速上手 fastNLP torch\n",
- "\n",
- "在这个例子中,我们将使用BERT来解决conll2003数据集中的命名实体识别任务。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "id": "e166c051",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "--2022-07-07 10:12:29-- https://data.deepai.org/conll2003.zip\n",
- "Resolving data.deepai.org (data.deepai.org)... 138.201.36.183\n",
- "Connecting to data.deepai.org (data.deepai.org)|138.201.36.183|:443... connected.\n",
- "WARNING: cannot verify data.deepai.org's certificate, issued by ‘CN=R3,O=Let's Encrypt,C=US’:\n",
- " Issued certificate has expired.\n",
- "HTTP request sent, awaiting response... 200 OK\n",
- "Length: 982975 (960K) [application/x-zip-compressed]\n",
- "Saving to: ‘conll2003.zip’\n",
- "\n",
- "conll2003.zip 100%[===================>] 959.94K 653KB/s in 1.5s \n",
- "\n",
- "2022-07-07 10:12:32 (653 KB/s) - ‘conll2003.zip’ saved [982975/982975]\n",
- "\n",
- "Archive: conll2003.zip\n",
- " inflating: conll2003/metadata \n",
- " inflating: conll2003/test.txt \n",
- " inflating: conll2003/train.txt \n",
- " inflating: conll2003/valid.txt \n"
- ]
- }
- ],
- "source": [
- "# Linux/Mac 下载数据,并解压\n",
- "import platform\n",
- "if platform.system() != \"Windows\":\n",
- " !wget https://data.deepai.org/conll2003.zip --no-check-certificate -O conll2003.zip\n",
- " !unzip conll2003.zip -d conll2003\n",
- "# Windows用户请通过复制该url到浏览器下载该数据并解压"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "f7acbf1f",
- "metadata": {},
- "source": [
- "## 目录\n",
- "接下来我们将按照以下内容,介绍如何通过fastNLP减少工程性代码的编写:\n",
- "- 1. 数据加载\n",
- "- 2. 数据预处理、数据缓存\n",
- "- 3. DataLoader\n",
- "- 4. 模型准备\n",
- "- 5. Trainer的使用\n",
- "- 6. Evaluator的使用\n",
- "- 7. 其它【待补充】\n",
- " - 7.1 使用多卡进行训练、评测\n",
- " - 7.2 使用ZeRO优化\n",
- " - 7.3 通过overfit测试快速验证模型\n",
- " - 7.4 复杂Monitor的使用\n",
- " - 7.5 训练过程中,使用不同的测试函数\n",
- " - 7.6 更有效率的Sampler\n",
- " - 7.7 保存模型\n",
- " - 7.8 断点重训\n",
- " - 7.9 使用huggingface datasets\n",
- " - 7.10 使用torchmetrics来作为metric\n",
- " - 7.11 将预测结果写出到文件\n",
- " - 7.12 混合 dataset 训练\n",
- " - 7.13 logger的使用\n",
- " - 7.14 自定义分布式 Metric\n",
- " - 7.15 通过batch_step_fn实现R-Drop"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0657dfba",
- "metadata": {},
- "source": [
- "### 1. 数据加载\n",
- "目前在``conll2003``目录下有``train.txt``, ``test.txt``与``valid.txt``三个文件,文件的格式为[conll格式](https://universaldependencies.org/format.html)(每个非空行包含以空白分隔的 4 列,最后一列为实体标签;句子之间以空行分隔),其实体标注采用 [BIO](https://blog.csdn.net/HappyRocking/article/details/79716212) 格式。可以通过继承 fastNLP.io.Loader 来简化加载过程:继承 Loader 类之后,只需要实现读取单个文件的 _load() 方法即可。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "id": "c557f0ba",
- "metadata": {},
- "outputs": [],
- "source": [
- "import sys\n",
- "sys.path.append('../..')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "6f59e438",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
- "</pre>\n"
- ],
- "text/plain": [
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "In total 3 datasets:\n",
- "\ttrain has 14987 instances.\n",
- "\ttest has 3684 instances.\n",
- "\tdev has 3466 instances.\n",
- "\n"
- ]
- }
- ],
- "source": [
- "from fastNLP import DataSet, Instance\n",
- "from fastNLP.io import Loader\n",
- "\n",
- "\n",
- "# 继承Loader之后,我们只需要实现其中_load()方法,_load()方法传入一个文件路径,返回一个fastNLP DataSet对象,其目的是读取一个文件。\n",
- "class ConllLoader(Loader):\n",
- " def _load(self, path):\n",
- " ds = DataSet()\n",
- " with open(path, 'r') as f:\n",
- " segments = []\n",
- " for line in f:\n",
- " line = line.strip()\n",
- " if line == '': # 如果为空行,说明需要切换到下一句了。\n",
- " if segments:\n",
- " raw_words = [s[0] for s in segments]\n",
- " raw_target = [s[1] for s in segments]\n",
- " # 将一个 sample 插入到 DataSet中\n",
- " ds.append(Instance(raw_words=raw_words, raw_target=raw_target)) \n",
- " segments = []\n",
- " else:\n",
- " parts = line.split()\n",
- " assert len(parts)==4\n",
- " segments.append([parts[0], parts[-1]])\n",
- " return ds\n",
- " \n",
- "\n",
- "# 直接使用 load() 方法加载数据集, 返回的 data_bundle 是一个 fastNLP.io.DataBundle 对象,该对象相当于将多个 dataset 放置在一起,\n",
- "# 可以方便之后的预处理,DataBundle 支持的接口可以在 fastNLP.io.DataBundle 的文档中查看。\n",
- "data_bundle = ConllLoader().load({\n",
- " 'train': 'conll2003/train.txt',\n",
- " 'test': 'conll2003/test.txt',\n",
- " 'dev': 'conll2003/valid.txt'\n",
- "})\n",
- "\"\"\"\n",
- "也可以通过 ConllLoader().load('conll2003/') 来读取,其原理是load()函数将尝试从'conll2003/'文件夹下寻找文件名称中包含了\n",
- "'train'、'test'和'dev'的文件,并分别读取将其命名为'train'、'test'和'dev'(如文件夹中同一个关键字出现在了多个文件名中将导致报错,\n",
- "此时请通过dict的方式传入路径信息)。但在我们这里的数据里,没有文件包含dev,所以无法直接使用文件夹读取,转而通过dict的方式传入读取的路径,\n",
- "该dict的key也将作为读取的数据集的名称,value即对应的文件路径。\n",
- "\"\"\"\n",
- "\n",
- "print(data_bundle) # 打印 data_bundle 可以查看包含的 DataSet \n",
- "# data_bundle.get_dataset('train') # 可以获取单个 dataset"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "57ae314d",
- "metadata": {},
- "source": [
- "### 2. 数据预处理、数据缓存\n",
- "接下来,我们将演示如何通过fastNLP提供的apply函数方便快捷地进行预处理。我们需要进行的预处理操作有:  \n",
- "(1)使用BertTokenizer将文本转换为index;同时记录每个word被bpe之后第一个bpe的index,用于得到word的hidden state;  \n",
- "(2)使用[Vocabulary](../fastNLP)来将raw_target转换为序号;  \n",
- "(3)使用 cache_results 缓存预处理的结果,避免重复进行耗时的预处理。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "96389988",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Output()"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
- ],
- "text/plain": []
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "c3bd41a323c94a41b409d29a5d4079b6",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Output()"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "IOPub message rate exceeded.\n",
- "The notebook server will temporarily stop sending output\n",
- "to the client in order to avoid crashing it.\n",
- "To change this limit, set the config variable\n",
- "`--NotebookApp.iopub_msg_rate_limit`.\n",
- "\n",
- "Current values:\n",
- "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
- "NotebookApp.rate_limit_window=3.0 (secs)\n",
- "\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
- ],
- "text/plain": []
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[10:48:13] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Save cache to <span style=\"color: #800080; text-decoration-color: #800080\">/remote-home/hyan01/exps/fastNLP/fastN</span> <a href=\"file://../../fastNLP/core/utils/cache_results.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">cache_results.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/core/utils/cache_results.py#332\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">332</span></a>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #800080; text-decoration-color: #800080\">LP/demo/torch_tutorial/caches/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">c7f74559_cache.pkl.</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "</pre>\n"
- ],
- "text/plain": [
- "\u001b[2;36m[10:48:13]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Save cache to \u001b[35m/remote-home/hyan01/exps/fastNLP/fastN\u001b[0m \u001b]8;id=831330;file://../../fastNLP/core/utils/cache_results.py\u001b\\\u001b[2mcache_results.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=609545;file://../../fastNLP/core/utils/cache_results.py#332\u001b\\\u001b[2m332\u001b[0m\u001b]8;;\u001b\\\n",
- "\u001b[2;36m \u001b[0m \u001b[35mLP/demo/torch_tutorial/caches/\u001b[0m\u001b[95mc7f74559_cache.pkl.\u001b[0m \u001b[2m \u001b[0m\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "# fastNLP 中提供了BERT, RoBERTa, GPT, BART 模型,更多的预训练模型请直接使用transformers\n",
- "from fastNLP.transformers.torch import BertTokenizer\n",
- "from fastNLP import cache_results, Vocabulary\n",
- "\n",
- "# 使用cache_results来装饰函数,会将函数的返回结果缓存到'caches/{param_hash_id}_cache.pkl'路径中(其中{param_hash_id}是根据\n",
- "# 传递给 process_data 函数参数决定的,因此当函数的参数变化时,会再生成新的缓存文件。如果需要重新生成新的缓存,(a) 可以在调用process_data\n",
- "# 函数时,额外传入一个_refresh=True的参数; 或者(b)删除相应的缓存文件。此外,保存结果时,cache_results默认还会\n",
- "# 记录 process_data 函数源码的hash值,当其源码发生了变动,直接读取缓存会发出警告,以防止在修改预处理代码之后,忘记刷新缓存。)\n",
- "@cache_results('caches/cache.pkl')\n",
- "def process_data(data_bundle, model_name):\n",
- " tokenizer = BertTokenizer.from_pretrained(model_name)\n",
- " def bpe(raw_words):\n",
- " bpes = [tokenizer.cls_token_id]\n",
- " first = [0]\n",
- " first_index = 1 # 记录第一个bpe的位置\n",
- " for word in raw_words:\n",
- " bpe = tokenizer.encode(word, add_special_tokens=False)\n",
- " bpes.extend(bpe)\n",
- " first.append(first_index)\n",
- " first_index += len(bpe)\n",
- " bpes.append(tokenizer.sep_token_id)\n",
- " first.append(first_index)\n",
- " return {'input_ids': bpes, 'input_len': len(bpes), 'first': first, 'first_len': len(raw_words)}\n",
- " # 对data_bundle中每个dataset的每一条数据中的raw_words使用bpe函数,并且将返回的结果加入到每条数据中。\n",
- " data_bundle.apply_field_more(bpe, field_name='raw_words', num_proc=4)\n",
- " # 对应我们还有 apply_field() 函数,该函数和 apply_field_more() 的区别在于传入到 apply_field() 中的函数应该返回一个 field 的\n",
- " # 内容(即不需要用dict包裹了)。此外,我们还提供了 data_bundle.apply() ,传入 apply() 的函数需要支持传入一个Instance对象,\n",
- " # 更多信息可以参考对应的文档。\n",
- " \n",
- " # tag的词表:由于这是标签词表(而非输入文本的词表),所以不需要 padding 和 unk\n",
- " tag_vocab = Vocabulary(padding=None, unknown=None)\n",
- " # 从 train 数据的 raw_target 中获取建立词表\n",
- " tag_vocab.from_dataset(data_bundle.get_dataset('train'), field_name='raw_target')\n",
- " # 使用词表将每个 dataset 中的raw_target转为数字,并且将写入到target这个field中\n",
- " tag_vocab.index_dataset(data_bundle.datasets.values(), field_name='raw_target', new_field_name='target')\n",
- " \n",
- " # 可以将 vocabulary 绑定到 data_bundle 上,方便之后使用。\n",
- " data_bundle.set_vocab(tag_vocab, field_name='target')\n",
- " \n",
- " return data_bundle, tokenizer\n",
- "\n",
- "data_bundle, tokenizer = process_data(data_bundle, 'bert-base-cased', _refresh=True) # 第一次调用耗时较长,第二次调用则会直接读取缓存的文件\n",
- "# data_bundle, tokenizer = process_data(data_bundle, 'bert-base-uncased') # 由于参数变化,fastNLP 会再次生成新的缓存文件。 "
- ]
- },
- {
- "cell_type": "markdown",
- "id": "80036fcd",
- "metadata": {},
- "source": [
- "### 3. DataLoader \n",
- "由于现在的深度学习算法大都基于 mini-batch 进行优化,因此需要将多个 sample 组合成一个 batch 再输入到模型之中。在自然语言处理中,不同的 sample 往往长度不一致,需要进行 padding 操作。在fastNLP中,我们使用 fastNLP.TorchDataLoader 帮助用户快速进行 padding:TorchDataLoader 使用 fastNLP.Collator 对象来进行 pad,Collator 会在迭代过程中根据第一个 batch 的数据自动判定每个 field 是否可以进行 pad。可以通过 Collator.set_pad() 函数(或直接调用 DataLoader 的 set_pad() 方法,见下方代码)修改某个 field 的 pad 行为。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "09494695",
- "metadata": {},
- "outputs": [],
- "source": [
- "from fastNLP import prepare_dataloader\n",
- "\n",
- "# 将 data_bundle 中每个 dataset 取出并构造出相应的 DataLoader 对象。返回的 dls 是一个 dict ,包含了 'train', 'test', 'dev' 三个\n",
- "# fastNLP.TorchDataLoader 对象。\n",
- "dls = prepare_dataloader(data_bundle, batch_size=24) \n",
- "\n",
- "\n",
- "# fastNLP 将默认尝试对所有 field 都进行 pad ,如果当前 field 是不可 pad 的类型,则不进行pad;如果是可以 pad 的类型\n",
- "# 默认使用 0 进行 pad 。\n",
- "for dl in dls.values():\n",
- " # 可以通过 set_pad 修改 padding 的行为。\n",
- " dl.set_pad('input_ids', pad_val=tokenizer.pad_token_id)\n",
- " # 如果希望忽略某个 field ,可以通过 set_ignore 方法。\n",
- " dl.set_ignore('raw_target')\n",
- " dl.set_pad('target', pad_val=-100)\n",
- "# 另一种设置的方法是,可以在 dls = prepare_dataloader(data_bundle, batch_size=24) 之前直接调用 \n",
- "# data_bundle.set_pad('input_ids', pad_val=tokenizer.pad_token_id); data_bundle.set_ignore('raw_target')来进行设置。\n",
- "# DataSet 也支持这两个方法。\n",
- "# 若此时调用 batch = next(iter(dls['train'])),则 batch 是一个 dict ,其中包含了\n",
- "# 'input_ids': torch.LongTensor([batch_size, max_len])\n",
- "# 'input_len': torch.LongTensor([batch_size])\n",
- "# 'first': torch.LongTensor([batch_size, max_len'])\n",
- "# 'first_len': torch.LongTensor([batch_size])\n",
- "# 'target': torch.LongTensor([batch_size, max_len'-2])\n",
- "# 'raw_words': List[List[str]] # 因为无法判断,所以 Collator 不会做任何处理"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "3583df6d",
- "metadata": {},
- "source": [
- "### 4. 模型准备\n",
- "传入给fastNLP的模型,可以定义两个特殊的方法``train_step``、``evaluate_step``,前者默认在 fastNLP.Trainer 中进行调用,后者默认在 fastNLP.Evaluator 中调用。如果模型中没有``train_step``方法,则Trainer会直接使用模型的``forward``函数;如果模型没有``evaluate_step``方法,则Evaluator会直接使用模型的``forward``函数。``train_step``方法(或当其不存在时,``forward``方法)的返回值必须为 dict 类型,并且必须包含``loss``这个 key 。\n",
- "\n",
- "此外fastNLP会使用形参名匹配的方式进行参数传递,例如以下模型\n",
- "```python\n",
- "class Model(nn.Module):\n",
- " def train_step(self, x, y):\n",
- " return {'loss': (x-y).abs().mean()}\n",
- "```\n",
- "fastNLP将尝试从 DataLoader 返回的 batch(假设包含的 key 为 input_ids, target) 中寻找 'x' 和 'y' 这两个 key ,如果没有找到则会报错。有以下的方法可以解决报错\n",
- "- 修改 train_step 的参数为(input_ids, target),以保证和 DataLoader 返回的 batch 中的 key 匹配\n",
- "- 修改 DataLoader 中返回 batch 的 key 的名字为 (x, y)\n",
- "- 在 Trainer 中传入参数 train_input_mapping={'input_ids': 'x', 'target': 'y'} 将输入进行映射,train_input_mapping 也可以是一个函数,更多 train_input_mapping 的介绍可以参考文档。\n",
- "\n",
- "``evaluate_step``也是使用同样的匹配方式,前两条解决方法是一致的,第三种解决方案中,需要在 Evaluator 中传入 evaluate_input_mapping={'input_ids': 'x', 'target': 'y'}。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "id": "f131c1a3",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[10:48:21] </span><span style=\"color: #800000; text-decoration-color: #800000\">WARNING </span> Some weights of the model checkpoint at <a href=\"file://../../fastNLP/transformers/torch/modeling_utils.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">modeling_utils.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/transformers/torch/modeling_utils.py#1490\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">1490</span></a>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> bert-base-uncased were not used when initializing <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> BertModel: <span style=\"font-weight: bold\">[</span><span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.bias'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.transform.LayerNorm.weight'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.seq_relationship.weight'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.decoder.weight'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.transform.dense.weight'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.transform.LayerNorm.bias'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.transform.dense.bias'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.seq_relationship.bias'</span><span style=\"font-weight: bold\">]</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> - This IS expected if you are initializing <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> BertModel from the checkpoint of a model trained <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> on another task or with another architecture <span style=\"font-weight: bold\">(</span>e.g. <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> initializing a BertForSequenceClassification model <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> from a BertForPreTraining model<span style=\"font-weight: bold\">)</span>. <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> - This IS NOT expected if you are initializing <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> BertModel from the checkpoint of a model that you <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> expect to be exactly identical <span style=\"font-weight: bold\">(</span>initializing a <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> BertForSequenceClassification model from a <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> BertForSequenceClassification model<span style=\"font-weight: bold\">)</span>. <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "</pre>\n"
- ],
- "text/plain": [
- "\u001b[2;36m[10:48:21]\u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m Some weights of the model checkpoint at \u001b]8;id=387614;file://../../fastNLP/transformers/torch/modeling_utils.py\u001b\\\u001b[2mmodeling_utils.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=648168;file://../../fastNLP/transformers/torch/modeling_utils.py#1490\u001b\\\u001b[2m1490\u001b[0m\u001b]8;;\u001b\\\n",
- "\u001b[2;36m \u001b[0m bert-base-uncased were not used when initializing \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m BertModel: \u001b[1m[\u001b[0m\u001b[32m'cls.predictions.bias'\u001b[0m, \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.LayerNorm.weight'\u001b[0m, \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m \u001b[32m'cls.seq_relationship.weight'\u001b[0m, \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.decoder.weight'\u001b[0m, \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.dense.weight'\u001b[0m, \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.LayerNorm.bias'\u001b[0m, \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.dense.bias'\u001b[0m, \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m \u001b[32m'cls.seq_relationship.bias'\u001b[0m\u001b[1m]\u001b[0m \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m - This IS expected if you are initializing \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m BertModel from the checkpoint of a model trained \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m on another task or with another architecture \u001b[1m(\u001b[0me.g. \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m initializing a BertForSequenceClassification model \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m from a BertForPreTraining model\u001b[1m)\u001b[0m. \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m - This IS NOT expected if you are initializing \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m BertModel from the checkpoint of a model that you \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m expect to be exactly identical \u001b[1m(\u001b[0minitializing a \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m BertForSequenceClassification model from a \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m BertForSequenceClassification model\u001b[1m)\u001b[0m. \u001b[2m \u001b[0m\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> All the weights of BertModel were initialized from <a href=\"file://../../fastNLP/transformers/torch/modeling_utils.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">modeling_utils.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/transformers/torch/modeling_utils.py#1507\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">1507</span></a>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> the model checkpoint at bert-base-uncased. <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> If your task is similar to the task the model of <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> the checkpoint was trained on, you can already use <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> BertModel for predictions without further <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> training. <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "</pre>\n"
- ],
- "text/plain": [
- "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m All the weights of BertModel were initialized from \u001b]8;id=544687;file://../../fastNLP/transformers/torch/modeling_utils.py\u001b\\\u001b[2mmodeling_utils.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=934505;file://../../fastNLP/transformers/torch/modeling_utils.py#1507\u001b\\\u001b[2m1507\u001b[0m\u001b]8;;\u001b\\\n",
- "\u001b[2;36m \u001b[0m the model checkpoint at bert-base-uncased. \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m If your task is similar to the task the model of \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m the checkpoint was trained on, you can already use \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m BertModel for predictions without further \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m training. \u001b[2m \u001b[0m\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "import torch\n",
- "from torch import nn\n",
- "from torch.nn.utils.rnn import pad_sequence\n",
- "from fastNLP.transformers.torch import BertModel\n",
- "from fastNLP import seq_len_to_mask\n",
- "import torch.nn.functional as F\n",
- "\n",
- "\n",
- "class BertNER(nn.Module):\n",
- " def __init__(self, model_name, num_class, tag_vocab=None):\n",
- " super().__init__()\n",
- " self.bert = BertModel.from_pretrained(model_name)\n",
- " self.mlp = nn.Sequential(nn.Linear(self.bert.config.hidden_size, self.bert.config.hidden_size),\n",
- " nn.Dropout(0.3),\n",
- " nn.Linear(self.bert.config.hidden_size, num_class))\n",
- " self.tag_vocab = tag_vocab # 这里传入 tag_vocab 的目的是为了演示 constrained_decode \n",
- " if tag_vocab is not None:\n",
- " self._init_constrained_transition()\n",
- " \n",
- " def forward(self, input_ids, input_len, first):\n",
- " attention_mask = seq_len_to_mask(input_len)\n",
- " outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
- " last_hidden_state = outputs.last_hidden_state\n",
- " first = first.unsqueeze(-1).repeat(1, 1, last_hidden_state.size(-1))\n",
- " first_bpe_state = last_hidden_state.gather(dim=1, index=first)\n",
- " first_bpe_state = first_bpe_state[:, 1:-1] # 删除 cls 和 sep\n",
- " \n",
- " pred = self.mlp(first_bpe_state)\n",
- " return {'pred': pred}\n",
- " \n",
- " def train_step(self, input_ids, input_len, first, target):\n",
- " pred = self(input_ids, input_len, first)['pred']\n",
- " loss = F.cross_entropy(pred.transpose(1, 2), target)\n",
- " return {'loss': loss}\n",
- " \n",
- " def evaluate_step(self, input_ids, input_len, first):\n",
- " pred = self(input_ids, input_len, first)['pred'].argmax(dim=-1)\n",
- " return {'pred': pred}\n",
- " \n",
- " def constrained_decode(self, input_ids, input_len, first, first_len):\n",
- " # 这个函数在推理时,将保证解码出来的 tag 一定不与前一个 tag 矛盾【例如一定不会出现 B-person 后面接着 I-Location 的情况】\n",
- " # 本身这个需求可以在 Metric 中实现,这里在模型中实现的目的是为了方便演示:如何在fastNLP中使用不同的评测函数\n",
- " pred = self(input_ids, input_len, first)['pred']\n",
- " cons_pred = []\n",
- " for _pred, _len in zip(pred, first_len):\n",
- " _pred = _pred[:_len]\n",
- " tags = [_pred[0].argmax(dim=-1).item()] # 这里就不考虑第一个位置非法的情况了\n",
- " for i in range(1, _len):\n",
- " tags.append((_pred[i] + self.transition[tags[-1]]).argmax().item())\n",
- " cons_pred.append(torch.LongTensor(tags))\n",
- " cons_pred = pad_sequence(cons_pred, batch_first=True)\n",
- " return {'pred': cons_pred}\n",
- " \n",
- " def _init_constrained_transition(self):\n",
- " from fastNLP.modules.torch import allowed_transitions\n",
- " allowed_trans = allowed_transitions(self.tag_vocab)\n",
- " transition = torch.ones((len(self.tag_vocab), len(self.tag_vocab)))*-100000.0\n",
- " for s, e in allowed_trans:\n",
- " transition[s, e] = 0\n",
- " self.register_buffer('transition', transition)\n",
- "\n",
- "model = BertNER('bert-base-uncased', len(data_bundle.get_vocab('target')), data_bundle.get_vocab('target'))"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "5aeee1e9",
- "metadata": {},
- "source": [
- "### 5. Trainer 的使用\n",
- "fastNLP 的 Trainer 是用于对模型进行训练的部件。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "f4250f0b",
- "metadata": {
- "scrolled": false
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[10:49:22] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches. <a href=\"file://../../fastNLP/core/controllers/trainer.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/core/controllers/trainer.py#661\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">661</span></a>\n",
- "</pre>\n"
- ],
- "text/plain": [
- "\u001b[2;36m[10:49:22]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=246773;file://../../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=639347;file://../../fastNLP/core/controllers/trainer.py#661\u001b\\\u001b[2m661\u001b[0m\u001b]8;;\u001b\\\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Output()"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
- ],
- "text/plain": []
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Output()"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
- "</pre>\n"
- ],
- "text/plain": [
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #00d75f; text-decoration-color: #00d75f\">+++++++++++++++++++++++++++++ </span><span style=\"font-weight: bold\">Eval. results on Epoch:</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span><span style=\"font-weight: bold\">, Batch:</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span><span style=\"color: #00d75f; text-decoration-color: #00d75f\"> +++++++++++++++++++++++++++++</span>\n",
- "</pre>\n"
- ],
- "text/plain": [
- "\u001b[38;5;41m+++++++++++++++++++++++++++++ \u001b[0m\u001b[1mEval. results on Epoch:\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1m, Batch:\u001b[0m\u001b[1;36m0\u001b[0m\u001b[38;5;41m +++++++++++++++++++++++++++++\u001b[0m\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
- " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#f\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.402447</span>,\n",
- " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#f\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.447906</span>,\n",
- " <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#f\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.365365</span>\n",
- "<span style=\"font-weight: bold\">}</span>\n",
- "</pre>\n"
- ],
- "text/plain": [
- "\u001b[1m{\u001b[0m\n",
- " \u001b[1;34m\"f#f\"\u001b[0m: \u001b[1;36m0.402447\u001b[0m,\n",
- " \u001b[1;34m\"pre#f\"\u001b[0m: \u001b[1;36m0.447906\u001b[0m,\n",
- " \u001b[1;34m\"rec#f\"\u001b[0m: \u001b[1;36m0.365365\u001b[0m\n",
- "\u001b[1m}\u001b[0m\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[10:51:15] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> The best performance for monitor f#<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">f:0</span>.<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">402447</span> was <a href=\"file://../../fastNLP/core/callbacks/progress_callback.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">progress_callback.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/core/callbacks/progress_callback.py#37\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">37</span></a>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> achieved in Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Global Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">625</span>. The <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> evaluation result: <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'f#f'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.402447</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'pre#f'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.447906</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'rec#f'</span>: <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.365365</span><span style=\"font-weight: bold\">}</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "</pre>\n"
- ],
- "text/plain": [
- "\u001b[2;36m[10:51:15]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m The best performance for monitor f#\u001b[1;92mf:0\u001b[0m.\u001b[1;36m402447\u001b[0m was \u001b]8;id=192029;file://../../fastNLP/core/callbacks/progress_callback.py\u001b\\\u001b[2mprogress_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=994998;file://../../fastNLP/core/callbacks/progress_callback.py#37\u001b\\\u001b[2m37\u001b[0m\u001b]8;;\u001b\\\n",
- "\u001b[2;36m \u001b[0m achieved in Epoch:\u001b[1;36m1\u001b[0m, Global Batch:\u001b[1;36m625\u001b[0m. The \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m evaluation result: \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m \u001b[1m{\u001b[0m\u001b[32m'f#f'\u001b[0m: \u001b[1;36m0.402447\u001b[0m, \u001b[32m'pre#f'\u001b[0m: \u001b[1;36m0.447906\u001b[0m, \u001b[32m'rec#f'\u001b[0m: \u001b[2m \u001b[0m\n",
- "\u001b[2;36m \u001b[0m \u001b[1;36m0.365365\u001b[0m\u001b[1m}\u001b[0m \u001b[2m \u001b[0m\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
- ],
- "text/plain": []
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
- "</pre>\n"
- ],
- "text/plain": [
- "\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Loading best model from buffer with f#f: <a href=\"file://../../fastNLP/core/callbacks/load_best_model_callback.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">load_best_model_callback.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/core/callbacks/load_best_model_callback.py#115\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">115</span></a>\n",
- "<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.402447</span><span style=\"color: #808000; text-decoration-color: #808000\">...</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
- "</pre>\n"
- ],
- "text/plain": [
- "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from buffer with f#f: \u001b]8;id=654516;file://../../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=96586;file://../../fastNLP/core/callbacks/load_best_model_callback.py#115\u001b\\\u001b[2m115\u001b[0m\u001b]8;;\u001b\\\n",
- "\u001b[2;36m \u001b[0m \u001b[1;36m0.402447\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "from torch import optim\n",
- "from fastNLP import Trainer, LoadBestModelCallback, TorchWarmupCallback\n",
- "from fastNLP import SpanFPreRecMetric\n",
- "\n",
- "optimizer = optim.AdamW(model.parameters(), lr=2e-5)\n",
- "callbacks = [\n",
- "    LoadBestModelCallback(),  # reload the weights of the best-performing model once training ends\n",
- "    TorchWarmupCallback()\n",
- "]\n",
- "\n",
- "trainer = Trainer(model=model, train_dataloader=dls['train'], optimizers=optimizer,\n",
- "                  evaluate_dataloaders=dls['dev'],\n",
- "                  metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))},\n",
- "                  n_epochs=1, callbacks=callbacks,\n",
- "                  # During evaluation, rename the dataloader's first_len field to seq_len, because\n",
- "                  # SpanFPreRecMetric.update takes the sequence lengths as a parameter named seq_len.\n",
- "                  evaluate_input_mapping={'first_len': 'seq_len'}, overfit_batches=0,\n",
- "                  # monitor 'f#f' is the 'f' score of the metric registered under the key 'f'\n",
- "                  device=0, monitor='f#f', fp16=False)  # set fp16=True to train in float16\n",
- "trainer.run()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "c600a450",
- "metadata": {},
- "source": [
- "### Evaluator 的使用\n",
- "fastNLP 的 Evaluator 是用于在给定数据集上评测模型性能的部件。"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "id": "1b19f0ba",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Output()"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
- ],
- "text/plain": []
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'f#f'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.390326</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'pre#f'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.414741</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'rec#f'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.368626</span><span style=\"font-weight: bold\">}</span>\n",
- "</pre>\n"
- ],
- "text/plain": [
- "\u001b[1m{\u001b[0m\u001b[32m'f#f'\u001b[0m: \u001b[1;36m0.390326\u001b[0m, \u001b[32m'pre#f'\u001b[0m: \u001b[1;36m0.414741\u001b[0m, \u001b[32m'rec#f'\u001b[0m: \u001b[1;36m0.368626\u001b[0m\u001b[1m}\u001b[0m\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/plain": [
- "{'f#f': 0.390326, 'pre#f': 0.414741, 'rec#f': 0.368626}"
- ]
- },
- "execution_count": 9,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from fastNLP import Evaluator\n",
- "from fastNLP import SpanFPreRecMetric\n",
- "\n",
- "# Score the trained model on the test split with the same SpanFPreRecMetric\n",
- "# (F / precision / recall) that was monitored during training.\n",
- "evaluator = Evaluator(\n",
- "    model=model,\n",
- "    dataloaders=dls['test'],\n",
- "    metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))},\n",
- "    # the metric reads the sequence lengths as seq_len, so expose first_len under that name\n",
- "    evaluate_input_mapping={'first_len': 'seq_len'},\n",
- "    device=0,\n",
- ")\n",
- "evaluator.run()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "52f87770",
- "metadata": {},
- "outputs": [],
- "source": [
- "# To evaluate with constrained decoding, pass evaluate_fn to select which model\n",
- "# function the Evaluator uses (here: constrained_decode).\n",
- "def input_mapping(x):\n",
- "    # Add seq_len for the metric while KEEPING first_len in the batch. A dict mapping\n",
- "    # {'first_len': 'seq_len'} would rename first_len away, and constrained_decode\n",
- "    # still needs first_len as one of its inputs.\n",
- "    x['seq_len'] = x['first_len']\n",
- "    return x\n",
- "\n",
- "evaluator = Evaluator(model=model, dataloaders=dls['test'], device=0,\n",
- "                      metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))},\n",
- "                      evaluate_fn='constrained_decode',\n",
- "                      evaluate_input_mapping=input_mapping)\n",
- "evaluator.run()"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.13"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
- }
|