{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ " Starting with this tutorial, we begin the **`example` series of the `fastNLP v0.8 tutorial`**. In each of the\n", "\n", " following `tutorial`s, we show how `fastNLP v0.8` is applied to a natural language processing task" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# E1. SST2 classification with Bert + fine-tuning\n", "\n", " 1 Background: the `GLUE` General Language Understanding Evaluation and the `SST2` binary sentiment classification dataset\n", "\n", " 2 Preparation: loading the `tokenizer`, preprocessing the `dataset`, using the `dataloader`\n", "\n", " 3 Model training: loading `distilbert-base`, `fastNLP` parameter matching, `fine-tuning`" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
 \n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "4.18.0\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.optim import AdamW\n", "from torch.utils.data import DataLoader, Dataset\n", "\n", "import transformers\n", "from transformers import AutoTokenizer\n", "from transformers import AutoModelForSequenceClassification\n", "\n", "import sys\n", "sys.path.append('..')\n", "\n", "import fastNLP\n", "from fastNLP import Trainer\n", "from fastNLP.core.metrics import Accuracy\n", "\n", "print(transformers.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Background: the GLUE General Language Understanding Evaluation and the SST2 binary sentiment classification dataset\n", "\n", " This example fine-tunes a `distilbert-base` classification model on the `SST2` dataset\n", "\n", " from the `GLUE` benchmark. We first give a brief introduction to `GLUE` and `SST2`\n", "\n", "**`GLUE`**, short for **`General Language Understanding Evaluation`**, is a benchmark\n", "\n", " of 9 datasets, all in English, covering several natural language understanding (`NLU`) tasks:\n", "\n", " **`CoLA`**, text classification: predict whether a single sentence is grammatically acceptable; **`SST2`**, text classification: predict the binary sentiment of a single sentence\n", "\n", " **`MRPC`**, sentence-pair classification: predict whether two sentences are semantically equivalent; **`STSB`**, similarity scoring: predict a regression score for the semantic similarity of a sentence pair\n", "\n", " **`QQP`**, sentence-pair classification: predict whether two questions are semantically equivalent; **`MNLI`**, textual inference: predict whether a sentence pair is entailment, contradiction, or neutral\n", "\n", " **`QNLI`/`RTE`/`WNLI`**, textual inference: binary prediction of whether entailment holds (`QNLI` is converted from `SQuAD`)\n", "\n", " Classic models such as `BERT` and `T5` report results on this benchmark; see the [GLUE paper](https://arxiv.org/pdf/1804.07461v3.pdf) for more\n", "\n", " Here we use `SST2` to train a `bert` model for text classification; the other tasks are described in the figure below" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "GLUE_TASKS = ['cola', 'mnli', 'mrpc', 'qnli', 'qqp', 'rte', 'sst2', 'stsb', 'wnli']\n", "\n", "task = 'sst2'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
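As a compact recap of the GLUE task list, the sketch below records each task's input text field(s) and label count in a small Python table. `GLUE_TASK_INFO` and `describe_task` are hypothetical helpers written for illustration only, not part of `fastNLP` or `transformers`; the field names assume the Hugging Face `datasets` version of GLUE.

```python
# Hypothetical lookup table (not a fastNLP or transformers API) summarising GLUE.
# Field names assume the Hugging Face `datasets` GLUE configs.
# Each entry: (first text field, second text field or None, number of labels).
GLUE_TASK_INFO = {
    'cola': ('sentence', None, 2),           # grammatical acceptability
    'sst2': ('sentence', None, 2),           # binary sentiment
    'mrpc': ('sentence1', 'sentence2', 2),   # paraphrase / equivalence
    'stsb': ('sentence1', 'sentence2', 1),   # regression: one similarity score
    'qqp': ('question1', 'question2', 2),    # duplicate questions
    'mnli': ('premise', 'hypothesis', 3),    # entailment / neutral / contradiction
    'qnli': ('question', 'sentence', 2),     # entailment vs not (from SQuAD)
    'rte': ('sentence1', 'sentence2', 2),    # entailment vs not
    'wnli': ('sentence1', 'sentence2', 2),   # entailment vs not
}


def describe_task(task):
    # Build a one-line summary of a task's input format and label space.
    first, second, num_labels = GLUE_TASK_INFO[task]
    inputs = 'single sentence' if second is None else 'sentence pair'
    fields = first if second is None else f'{first}, {second}'
    target = 'regression' if num_labels == 1 else f'{num_labels}-way classification'
    return f'{task}: {inputs} [{fields}], {target}'


# Print a summary line for every task in the table.
for name in GLUE_TASK_INFO:
    print(describe_task(name))
```

For `sst2` this reports a single-sentence, 2-way classification task, which is why one text field and a two-class head are enough for this tutorial.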
 \n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "{'acc#acc': 0.87156, 'total#acc': 872.0, 'correct#acc': 760.0}" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.evaluator.run()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Appendix: `DistilBertForSequenceClassification` module structure\n", "\n", "```\n", "