{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "  With this article we begin the **`example` series of the `fastNLP v0.8 tutorial`**.\n", "\n", "  Each `tutorial` in the series shows how `fastNLP v0.8` is applied to a natural language processing task." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# E1. Fine-tuning Bert for SST2 classification\n", "\n", "  1   Background: the `GLUE` general language understanding evaluation and the `SST2` binary sentiment classification dataset\n", "\n", "  2   Preparation: loading the `tokenizer`, preprocessing the `dataset`, and using the `dataloader`\n", "\n", "  3   Model training: loading `distilbert-base`, matching parameters in `fastNLP`, and `fine-tuning`" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n",
       "
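\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "4.18.0\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.optim import AdamW\n", "from torch.utils.data import DataLoader, Dataset\n", "\n", "import transformers\n", "from transformers import AutoTokenizer\n", "from transformers import AutoModelForSequenceClassification\n", "\n", "import sys\n", "sys.path.append('..')\n", "\n", "import fastNLP\n", "from fastNLP import Trainer\n", "from fastNLP.core.metrics import Accuracy\n", "\n", "print(transformers.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Background: the GLUE general language understanding evaluation and the SST2 binary sentiment classification dataset\n", "\n", "  This example uses the `SST2` dataset from the `GLUE` benchmark and adapts a `distilbert-base` classification model by `fine-tuning`. We first give a brief introduction to `GLUE` and `SST2`.\n", "\n", "**`GLUE`**, short for **`General Language Understanding Evaluation`**, is a benchmark of 9 datasets, all in English, that covers several natural language understanding (`NLU`) tasks:\n", "\n", "  **`CoLA`**: text classification, judging whether a single sentence is grammatical; **`SST2`**: text classification, predicting the binary sentiment of a single sentence\n", "\n", "  **`MRPC`**: sentence-pair classification, predicting whether two sentences are semantically equivalent; **`STSB`**: similarity scoring, a regression on the semantic similarity of a sentence pair\n", "\n", "  **`QQP`**: sentence-pair classification, predicting whether two questions are semantically equivalent; **`MNLI`**: textual inference, classifying a sentence pair as entailment, contradiction, or neutral\n", "\n", "  **`QNLI`/`RTE`/`WNLI`**: textual inference, a binary classification of whether entailment holds (`QNLI` is derived from `SQuAD`)\n", "\n", "  Classic models such as `BERT` and `T5` are evaluated on this benchmark; for details see the [GLUE paper](https://arxiv.org/pdf/1804.07461v3.pdf)\n", "\n", "  Here we train `bert` on `SST2` for text classification; the other tasks are described in the figure below" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "GLUE_TASKS = ['cola', 'mnli', 'mrpc', 'qnli', 'qqp', 'rte', 'sst2', 'stsb', 'wnli']\n", "\n", "task = 'sst2'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "**`SST`**, short for **`Stanford Sentiment Treebank`**, is a **single-sentence sentiment classification** dataset\n", "\n", "  It contains movie review sentences with their sentiment polarity: 1 for `positive` and 0 for `negative`\n", "\n", "  The dataset has three parts: 67349 training examples, 872 development examples, and 1821 test examples; for details see the [download link](https://gluebenchmark.com/tasks)\n", "\n",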
"In code, we use the `load_dataset` function from the `datasets` package and specify the `SST2` dataset, which is then loaded automatically\n", "\n", "  After the first download, the loading script is cached under `~/.cache/huggingface/modules/datasets_modules/datasets/glue/` and the data under `~/.cache/huggingface/datasets/glue/`" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "scrolled": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.\n", "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "adc9449171454f658285f220b70126e1", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "
\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'acc#acc': 0.87156, 'total#acc': 872.0, 'correct#acc': 760.0}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.evaluator.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Appendix: `DistilBertForSequenceClassification` module structure\n",
    "\n",
    "```\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}