{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ " Starting with this article, we begin the **`example` series of the `fastNLP v0.8 tutorial`**. In each of the\n", "\n", " following `tutorial`s, we present an application of `fastNLP v0.8` to a natural language processing task" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# E1. Fine-tuning BERT for SST-2 classification\n", "\n", " 1 Background: the `GLUE` General Language Understanding Evaluation benchmark and the `SST-2` binary sentiment classification dataset\n", "\n", " 2 Preparation: loading the `tokenizer`, preprocessing the `dataset`, using the `dataloader`\n", "\n", " 3 Model training: loading `distilbert-base`, `fastNLP` parameter matching, `fine-tuning`" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "4.18.0\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.optim import AdamW\n", "from torch.utils.data import DataLoader, Dataset\n", "\n", "import transformers\n", "from transformers import AutoTokenizer\n", "from transformers import AutoModelForSequenceClassification\n", "\n", "import sys\n", "sys.path.append('..')\n", "\n", "import fastNLP\n", "from fastNLP import Trainer\n", "from fastNLP import Accuracy\n", "\n", "print(transformers.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Background: the GLUE General Language Understanding Evaluation benchmark and the SST-2 binary sentiment classification dataset\n", "\n", " This example uses the `SST-2` dataset from the `GLUE` benchmark to `fine-tune`\n", "\n", " a `distilbert-base` classification model. We first briefly introduce `GLUE` and `SST-2`\n", "\n", "**`GLUE`**, **short for `General Language Understanding Evaluation`**,\n", "\n", " comprises 9 datasets, all in English, covering several natural language understanding (`NLU`) tasks, including\n", "\n", " **`CoLA`**, text classification: predict whether a single sentence is grammatically correct; **`SST-2`**, text classification: predict the binary sentiment of a single sentence\n", "\n", " **`MRPC`**, sentence-pair classification: predict whether two sentences are semantically equivalent; **`STS-B`**, similarity scoring: predict the semantic similarity of a sentence pair as a regression target\n", "\n", " **`QQP`**, sentence-pair classification: predict whether two questions are semantically equivalent; **`MNLI`**, textual inference: predict entailment/contradiction/neutral for a sentence pair\n", "\n", " **`QNLI`/`RTE`/`WNLI`**, textual inference: binary prediction of whether entailment holds (`QNLI` is derived from `SQuAD`)\n", "\n", " Classic models such as `BERT` and `T5` are evaluated on this benchmark; for more details see the [GLUE paper](https://arxiv.org/pdf/1804.07461v3.pdf)\n", "\n", " Here we use `SST-2` to train `bert` for text classification; the other tasks are described in the figure below" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "GLUE_TASKS = ['cola', 'mnli', 'mrpc', 'qnli', 'qqp', 'rte', 'sst2', 'stsb', 'wnli']\n", "\n", "task = 'sst2'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
[09:12:45] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", "\n" ], "text/plain": [ "\u001b[2;36m[09:12:45]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=408427;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=303634;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.884375,\n", " \"total#acc\": 320.0,\n", " \"correct#acc\": 283.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.878125,\n", " \"total#acc\": 320.0,\n", " \"correct#acc\": 281.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.884375,\n", " \"total#acc\": 320.0,\n", " \"correct#acc\": 283.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.9,\n", " \"total#acc\": 320.0,\n", " \"correct#acc\": 288.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.9\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m288.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.8875,\n", " \"total#acc\": 320.0,\n", " \"correct#acc\": 284.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.88125,\n", " \"total#acc\": 320.0,\n", " \"correct#acc\": 282.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.88125\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m282.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.875,\n", " \"total#acc\": 320.0,\n", " \"correct#acc\": 280.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.865625,\n", " \"total#acc\": 320.0,\n", " \"correct#acc\": 277.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.865625\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m277.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.884375,\n", " \"total#acc\": 320.0,\n", " \"correct#acc\": 283.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.878125,\n", " \"total#acc\": 320.0,\n", " \"correct#acc\": 281.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "{'acc#acc': 0.884174, 'total#acc': 872.0, 'correct#acc': 771.0}" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.evaluator.run()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Appendix: `DistilBertForSequenceClassification` module structure\n", "\n", "```\n", "