diff --git a/tutorials/fastnlp_tutorial_1.ipynb b/tutorials/fastnlp_tutorial_1.ipynb
index 09e8821d..db77e6c3 100644
--- a/tutorials/fastnlp_tutorial_1.ipynb
+++ b/tutorials/fastnlp_tutorial_1.ipynb
@@ -1325,7 +1325,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.7.4"
+ "version": "3.7.13"
}
},
"nbformat": 4,
diff --git a/tutorials/fastnlp_tutorial_3.ipynb b/tutorials/fastnlp_tutorial_3.ipynb
index 8c3c935e..353e4645 100644
--- a/tutorials/fastnlp_tutorial_3.ipynb
+++ b/tutorials/fastnlp_tutorial_3.ipynb
@@ -288,7 +288,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.7.4"
+ "version": "3.7.13"
},
"pycharm": {
"stem_cell": {
diff --git a/tutorials/fastnlp_tutorial_e1.ipynb b/tutorials/fastnlp_tutorial_e1.ipynb
index 628dd7ae..6ec04cb4 100644
--- a/tutorials/fastnlp_tutorial_e1.ipynb
+++ b/tutorials/fastnlp_tutorial_e1.ipynb
@@ -4,7 +4,22 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# E1. 使用 DistilBert 完成 SST2 分类"
+ " 从这篇开始,我们将开启**`fastNLP v0.8 tutorial`的`example`系列**,在接下来的\n",
+ "\n",
+ " 每篇`tutorial`里,我们将会介绍`fastNLP v0.8`在一些自然语言处理任务上的应用"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# E1. 使用 Bert + fine-tuning 完成 SST2 分类\n",
+ "\n",
+ " 1 基础介绍:`GLUE`通用语言理解评估、`SST2`文本情感二分类数据集 \n",
+ "\n",
+ " 2 准备工作:加载`tokenizer`、预处理`dataset`、`dataloader`使用\n",
+ "\n",
+ " 3 模型训练:加载`distilbert-base`、`fastNLP`参数匹配、`fine-tuning`"
]
},
{
@@ -48,22 +63,64 @@
"\n",
"import fastNLP\n",
"from fastNLP import Trainer\n",
- "from fastNLP.core.utils.utils import dataclass_to_dict\n",
"from fastNLP.core.metrics import Accuracy\n",
"\n",
"print(transformers.__version__)"
]
},
{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1. 基础介绍:GLUE 通用语言理解评估、SST2 文本情感二分类数据集\n",
+ "\n",
+ " 本示例使用`GLUE`评估基准中的`SST2`数据集,通过`fine-tuning`方式\n",
+ "\n",
+ " 调整`distilbert-bert`分类模型,以下首先简单介绍下`GLUE`和`SST2`\n",
+ "\n",
+ "**`GLUE`**,**全称`General Language Understanding Evaluation`**,**通用语言理解评估**,\n",
+ "\n",
+ " 包含9个数据集,各语料的语言均为英语,涉及多个自然语言理解`NLU`任务,包括\n",
+ "\n",
+ " **`CoLA`**,文本分类任务,预测单句语法正误分类;**`SST2`**,文本分类任务,预测单句情感二分类\n",
+ "\n",
+ " **`MRPC`**,句对分类任务,预测句对语义一致性;**`STSB`**,相似度打分任务,预测句对语义相似度回归\n",
+ "\n",
+ " **`QQP`**,句对分类任务,预测问题对语义一致性;**`MNLI`**,文本推理任务,预测句对蕴含/矛盾/中立预测\n",
+ "\n",
+ " **`QNLI`/`RTE`/`WNLI`**,文本推理,预测是否蕴含二分类(其中,`QNLI`从`SQuAD`转化而来\n",
+ "\n",
+ " 诸如`BERT`、`T5`等经典模型都会在此基准上验证效果,更多参考[GLUE论文](https://arxiv.org/pdf/1804.07461v3.pdf)\n",
+ "\n",
+ " 此处,我们使用`SST2`来训练`bert`,实现文本分类,其他任务描述见下图"
+ ]
+ },
+ {
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
- "GLUE_TASKS = [\"cola\", \"mnli\", \"mnli-mm\", \"mrpc\", \"qnli\", \"qqp\", \"rte\", \"sst2\", \"stsb\", \"wnli\"]\n",
+ "GLUE_TASKS = ['cola', 'mnli', 'mrpc', 'qnli', 'qqp', 'rte', 'sst2', 'stsb', 'wnli']\n",
+ "\n",
+ "task = 'sst2'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "\n",
+ "**`SST`**,**全称`Stanford Sentiment Treebank`**,**斯坦福情感树库**,**单句情感分类**数据集\n",
+ "\n",
+ " 包含电影评论语句和对应的情感极性,1 对应`positive` 正面情感,0 对应`negative` 负面情感\n",
+ "\n",
+ " 数据集包括三部分:训练集 67350 条,开发集 873 条,测试集 1821 条,更多参考[下载链接](https://gluebenchmark.com/tasks)\n",
"\n",
- "task = \"sst2\"\n",
- "model_checkpoint = \"distilbert-base-uncased\""
+ "对应到代码上,此处使用`datasets`模块中的`load_dataset`函数,指定`SST2`数据集,自动加载\n",
+ "\n",
+ " 首次下载后会保存至`~/.cache/huggingface/modules/datasets_modules/datasets/glue/`目录下"
]
},
{
@@ -84,7 +141,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "253d79d7a67e4dc88338448b5bcb3fb9",
+ "model_id": "adc9449171454f658285f220b70126e1",
"version_major": 2,
"version_minor": 0
},
@@ -97,9 +154,16 @@
}
],
"source": [
- "from datasets import load_dataset, load_metric\n",
+ "from datasets import load_dataset\n",
"\n",
- "dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)"
+ "dataset = load_dataset('glue', task)"
+ ]
+ },
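+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick sanity check (an illustrative sketch added here; the split sizes in the comments are what the `datasets` library reports for `glue/sst2` and are an assumption, not part of the original notebook), we can print the returned `DatasetDict`:\n",
+    "\n",
+    "```python\n",
+    "print(dataset)\n",
+    "# DatasetDict({\n",
+    "#     train: Dataset({features: ['sentence', 'label', 'idx'], num_rows: 67349})\n",
+    "#     validation: Dataset({features: ['sentence', 'label', 'idx'], num_rows: 872})\n",
+    "#     test: Dataset({features: ['sentence', 'label', 'idx'], num_rows: 1821})\n",
+    "# })\n",
+    "```"
+   ]
+  },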
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " 加载之后,根据`GLUE`中`SST2`数据集的格式,尝试打印部分数据,检查加载结果"
]
},
{
@@ -111,62 +175,89 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n"
+ "Sentence: hide new secretions from the parental units \n"
]
}
],
"source": [
- "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n",
+ "task_to_keys = {\n",
+ " 'cola': ('sentence', None),\n",
+ " 'mnli': ('premise', 'hypothesis'),\n",
+ " 'mnli': ('premise', 'hypothesis'),\n",
+ " 'mrpc': ('sentence1', 'sentence2'),\n",
+ " 'qnli': ('question', 'sentence'),\n",
+ " 'qqp': ('question1', 'question2'),\n",
+ " 'rte': ('sentence1', 'sentence2'),\n",
+ " 'sst2': ('sentence', None),\n",
+ " 'stsb': ('sentence1', 'sentence2'),\n",
+ " 'wnli': ('sentence1', 'sentence2'),\n",
+ "}\n",
"\n",
- "print(tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\"))"
+ "sentence1_key, sentence2_key = task_to_keys[task]\n",
+ "\n",
+ "if sentence2_key is None:\n",
+ " print(f\"Sentence: {dataset['train'][0][sentence1_key]}\")\n",
+ "else:\n",
+ " print(f\"Sentence 1: {dataset['train'][0][sentence1_key]}\")\n",
+ " print(f\"Sentence 2: {dataset['train'][0][sentence2_key]}\")"
]
},
{
- "cell_type": "code",
- "execution_count": 5,
+ "cell_type": "markdown",
"metadata": {},
- "outputs": [],
"source": [
- "task_to_keys = {\n",
- " \"cola\": (\"sentence\", None),\n",
- " \"mnli\": (\"premise\", \"hypothesis\"),\n",
- " \"mnli-mm\": (\"premise\", \"hypothesis\"),\n",
- " \"mrpc\": (\"sentence1\", \"sentence2\"),\n",
- " \"qnli\": (\"question\", \"sentence\"),\n",
- " \"qqp\": (\"question1\", \"question2\"),\n",
- " \"rte\": (\"sentence1\", \"sentence2\"),\n",
- " \"sst2\": (\"sentence\", None),\n",
- " \"stsb\": (\"sentence1\", \"sentence2\"),\n",
- " \"wnli\": (\"sentence1\", \"sentence2\"),\n",
- "}\n",
+ "### 2. 准备工作:加载 tokenizer、预处理 dataset、dataloader 使用\n",
+ "\n",
+ " 接下来进入模型训练的准备工作,分别需要使用`tokenizer`模块对数据集进行分词与标注\n",
+ "\n",
+ " 定义`SeqClsDataset`对应`dataloader`模块用来实现数据集在训练/测试时的加载\n",
+ "\n",
+ "此处的`tokenizer`和`SequenceClassificationModel`都是基于**`distilbert-base-uncased`模型**\n",
"\n",
- "sentence1_key, sentence2_key = task_to_keys[task]"
+ " 即使用较小的、不区分大小写的数据集,**对`bert-base`进行知识蒸馏后的版本**,结构上\n",
+ "\n",
+ " 模型包含1个编码层、6个自注意力层,详解见本篇末尾,更多细节请参考[DistilBert论文](https://arxiv.org/pdf/1910.01108.pdf)\n",
+ "\n",
+ "首先,通过从`transformers`库中导入`AutoTokenizer`模块,使用`from_pretrained`函数初始化\n",
+ "\n",
+ " 此处的`use_fast`表示是否使用`tokenizer`的快速版本;尝试序列化示例数据,检查加载结果\n",
+ "\n",
+ " 需要注意的是,处理后返回的两个键值,`'input_ids'`表示原始文本对应的词素编号序列\n",
+ "\n",
+ " `'attention_mask'`表示自注意力运算时的掩模(标上`0`的部分对应`padding`的内容"
]
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Sentence: hide new secretions from the parental units \n"
+ "{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n"
]
}
],
"source": [
- "if sentence2_key is None:\n",
- " print(f\"Sentence: {dataset['train'][0][sentence1_key]}\")\n",
- "else:\n",
- " print(f\"Sentence 1: {dataset['train'][0][sentence1_key]}\")\n",
- " print(f\"Sentence 2: {dataset['train'][0][sentence2_key]}\")"
+ "model_checkpoint = 'distilbert-base-uncased'\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n",
+ "\n",
+ "print(tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\"))"
+ ]
+ },
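+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To make the id sequence above readable (an illustrative sketch, not in the original notebook): decoding the ids back to text exposes the special tokens; `101` and `102` are `[CLS]` and `[SEP]` in the bert-family uncased vocabulary:\n",
+    "\n",
+    "```python\n",
+    "ids = tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\")['input_ids']\n",
+    "print(tokenizer.decode(ids))\n",
+    "# [CLS] hello, this one sentence! [SEP] and this sentence goes with it. [SEP]\n",
+    "```"
+   ]
+  },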
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "接着,定义预处理函数,**通过`dataset.map`方法**,**将数据集中的文本**,**替换为词素编号序列**"
]
},
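+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a sketch of that preprocessing step (the exact cell source is not shown in this diff; `preprocess_function` and `truncation=True` are assumed from the standard `datasets.map` pattern, and only the `encoded_dataset` name is confirmed by later cells):\n",
+    "\n",
+    "```python\n",
+    "def preprocess_function(examples):\n",
+    "    # tokenize one or two text columns, depending on the task\n",
+    "    if sentence2_key is None:\n",
+    "        return tokenizer(examples[sentence1_key], truncation=True)\n",
+    "    return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)\n",
+    "\n",
+    "encoded_dataset = dataset.map(preprocess_function, batched=True)\n",
+    "```"
+   ]
+  },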
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -189,66 +280,27 @@
]
},
{
- "cell_type": "code",
- "execution_count": 8,
+ "cell_type": "markdown",
"metadata": {},
- "outputs": [],
"source": [
- "class ClassModel(nn.Module):\n",
- " def __init__(self, num_labels, model_checkpoint):\n",
- " nn.Module.__init__(self)\n",
- " self.num_labels = num_labels\n",
- " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n",
- " num_labels=num_labels)\n",
- " self.loss_fn = nn.CrossEntropyLoss()\n",
- "\n",
- " def forward(self, input_ids, attention_mask):\n",
- " return self.back_bone(input_ids, attention_mask)\n",
+ "然后,通过继承`torch`中的`Dataset`类,定义`SeqClsDataset`类,需要注意的是\n",
"\n",
- " def train_step(self, input_ids, attention_mask, labels):\n",
- " pred = self(input_ids, attention_mask).logits\n",
- " return {\"loss\": self.loss_fn(pred, labels)}\n",
- "\n",
- " def evaluate_step(self, input_ids, attention_mask, labels):\n",
- " pred = self(input_ids, attention_mask).logits\n",
- " pred = torch.max(pred, dim=-1)[1]\n",
- " return {\"pred\": pred, \"target\": labels}"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight']\n",
- "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
- "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
- "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'classifier.bias', 'pre_classifier.bias']\n",
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
- ]
- }
- ],
- "source": [
- "num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n",
+ " 其中,**`__getitem__`函数各返回值引用的键值**,**必须和原始数据集中的属性对应**\n",
"\n",
- "model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n",
+ " 例如,`'label'`是`SST2`数据集中原有的内容(包括`'sentence'`和`'label'`\n",
"\n",
- "optimizers = AdamW(params=model.parameters(), lr=5e-5)"
+ " `'input_ids'`和`'attention_mask'`则是`tokenizer`处理后添加的字段"
]
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
- "class TestDistilBertDataset(Dataset):\n",
+ "class SeqClsDataset(Dataset):\n",
" def __init__(self, dataset):\n",
- " super(TestDistilBertDataset, self).__init__()\n",
+ " Dataset.__init__(self)\n",
" self.dataset = dataset\n",
"\n",
" def __len__(self):\n",
@@ -256,16 +308,27 @@
"\n",
" def __getitem__(self, item):\n",
" item = self.dataset[item]\n",
- " return item[\"input_ids\"], item[\"attention_mask\"], [item[\"label\"]] "
+ " return item['input_ids'], item['attention_mask'], [item['label']] "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "再然后,**定义校对函数`collate_fn`对齐同个`batch`内的每笔数据**,需要注意的是该函数的\n",
+ "\n",
+ " **返回值必须是字典**,**键值必须同待训练模型的`train_step`和`evaluate_step`函数的参数**\n",
+ "\n",
+ " **相对应**;这也就是在`tutorial-0`中便被强调的,`fastNLP v0.8`的第一条**参数匹配**机制"
]
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
- "def test_bert_collate_fn(batch):\n",
+ "def collate_fn(batch):\n",
" input_ids, atten_mask, labels = [], [], []\n",
" max_length = [0] * 3\n",
" for each_item in batch:\n",
@@ -280,35 +343,136 @@
" each = (input_ids, atten_mask, labels)[i]\n",
" for item in each:\n",
" item.extend([0] * (max_length[i] - len(item)))\n",
- " return {\"input_ids\": torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n",
- " \"attention_mask\": torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n",
- " \"labels\": torch.cat([torch.tensor(item) for item in labels], dim=0)}"
+ " return {'input_ids': torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n",
+ " 'attention_mask': torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n",
+ " 'labels': torch.cat([torch.tensor(item) for item in labels], dim=0)}"
+ ]
+ },
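+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A quick way to see the padding behaviour (an illustrative sketch with a made-up two-item batch; the expected values in the comments follow from the `collate_fn` code above):\n",
+    "\n",
+    "```python\n",
+    "toy_batch = [([101, 7592, 102], [1, 1, 1], [0]),\n",
+    "             ([101, 2023, 2028, 6251, 102], [1, 1, 1, 1, 1], [1])]\n",
+    "batch = collate_fn(toy_batch)\n",
+    "print(batch['input_ids'].shape)    # torch.Size([2, 5]) -- shorter item padded with 0\n",
+    "print(batch['attention_mask'][0])  # tensor([1, 1, 1, 0, 0])\n",
+    "print(batch['labels'])             # tensor([0, 1])\n",
+    "```"
+   ]
+  },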
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "最后,分别对`tokenizer`处理过的训练集数据、验证集数据,进行预处理和批量划分"
]
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
- "dataset_train = TestDistilBertDataset(encoded_dataset[\"train\"])\n",
+ "dataset_train = SeqClsDataset(encoded_dataset['train'])\n",
"dataloader_train = DataLoader(dataset=dataset_train, \n",
- " batch_size=32, shuffle=True, collate_fn=test_bert_collate_fn)\n",
- "dataset_valid = TestDistilBertDataset(encoded_dataset[\"validation\"])\n",
+ " batch_size=32, shuffle=True, collate_fn=collate_fn)\n",
+ "dataset_valid = SeqClsDataset(encoded_dataset['validation'])\n",
"dataloader_valid = DataLoader(dataset=dataset_valid, \n",
- " batch_size=32, shuffle=False, collate_fn=test_bert_collate_fn)"
+ " batch_size=32, shuffle=False, collate_fn=collate_fn)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3. 模型训练:加载 distilbert-base、fastNLP 参数匹配、fine-tuning\n",
+ "\n",
+ " 最后就是模型训练的,分别需要使用`distilbert-base-uncased`搭建分类模型\n",
+ "\n",
+ " 初始化优化器`optimizer`、训练模块`trainer`,通过`run`函数完成训练\n",
+ "\n",
+ "此处使用的`nn.Module`模块搭建模型,与`tokenizer`类似,通过从`transformers`库中\n",
+ "\n",
+ " 导入`AutoModelForSequenceClassification`模块,基于`distilbert-base-uncased`模型初始\n",
+ "\n",
+ "需要注意的是**`AutoModelForSequenceClassification`模块的输入参数和输出结构**\n",
+ "\n",
+ " 一方面,可以**通过输入标签值`labels`**,**使用模块内的损失函数计算损失`loss`**\n",
+ "\n",
+ " 并且可以选择输入是词素编号序列`input_ids`,还是词素嵌入序列`inputs_embeds`\n",
+ "\n",
+ " 另方面,该模块不会直接输出预测结果,而是会**输出各预测分类上的几率`logits`**\n",
+ "\n",
+ " 基于上述描述,此处完成了中`train_step`和`evaluate_step`函数的定义\n",
+ "\n",
+ " 同样需要注意,函数的返回值体现了`fastNLP v0.8`的第二条**参数匹配**机制"
]
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class SeqClsModel(nn.Module):\n",
+ " def __init__(self, num_labels, model_checkpoint):\n",
+ " nn.Module.__init__(self)\n",
+ " self.num_labels = num_labels\n",
+ " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n",
+ " num_labels=num_labels)\n",
+ "\n",
+ " def forward(self, input_ids, attention_mask, labels=None):\n",
+ " output = self.back_bone(input_ids=input_ids, \n",
+ " attention_mask=attention_mask, labels=labels)\n",
+ " return output\n",
+ "\n",
+ " def train_step(self, input_ids, attention_mask, labels):\n",
+ " loss = self(input_ids, attention_mask, labels).loss\n",
+ " return {'loss': loss}\n",
+ "\n",
+ " def evaluate_step(self, input_ids, attention_mask, labels):\n",
+ " pred = self(input_ids, attention_mask, labels).logits\n",
+ " pred = torch.max(pred, dim=-1)[1]\n",
+ " return {'pred': pred, 'target': labels}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "接着,通过确定分类数量初始化模型实例,同时调用`torch.optim.AdamW`模块初始化优化器"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight']\n",
+ "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.weight', 'classifier.bias', 'pre_classifier.weight']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ }
+ ],
+ "source": [
+ "num_labels = 3 if task == 'mnli' else 1 if task == 'stsb' else 2\n",
+ "\n",
+ "model = SeqClsModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n",
+ "\n",
+ "optimizers = AdamW(params=model.parameters(), lr=5e-5)"
+ ]
+ },
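+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a sketch of the input/output structure described above (illustrative only; the sample sentence is made up), a forward pass with `labels` supplied returns both `loss` and `logits`:\n",
+    "\n",
+    "```python\n",
+    "toy = tokenizer('a gorgeous film', return_tensors='pt')\n",
+    "out = model(toy['input_ids'], toy['attention_mask'], labels=torch.tensor([1]))\n",
+    "print(out.loss.shape, out.logits.shape)  # torch.Size([]) torch.Size([1, 2])\n",
+    "```"
+   ]
+  },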
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "然后,使用之前完成的`dataloader_train`和`dataloader_valid`,定义训练模块`trainer`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"trainer = Trainer(\n",
" model=model,\n",
" driver='torch',\n",
- " device='cuda',\n",
+ " device=1, # 'cuda'\n",
" n_epochs=10,\n",
" optimizers=optimizers,\n",
" train_dataloader=dataloader_train,\n",
@@ -318,42 +482,35 @@
]
},
{
- "cell_type": "code",
- "execution_count": 14,
+ "cell_type": "markdown",
"metadata": {},
- "outputs": [],
"source": [
- "# help(model.back_bone.forward)"
+ "最后,使用`trainer.run`方法,训练模型,`n_epochs`参数中已经指定需要迭代`10`轮\n",
+ "\n",
+ " `num_eval_batch_per_dl`参数则指定每次只对验证集中的`10`个`batch`进行评估"
]
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "
[21:00:11] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", - "\n" + "\n" ], - "text/plain": [ - "\u001b[2;36m[21:00:11]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=22992;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=669026;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" - ] + "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] + "text/html": [ + "\n" + ], + "text/plain": [] }, "metadata": {}, "output_type": "display_data" @@ -370,16 +527,23 @@ }, "metadata": {}, "output_type": "display_data" - }, + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ { "data": { "text/html": [ - "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", - "\n" + "\n" ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] + "text/plain": [] }, "metadata": {}, "output_type": "display_data" @@ -387,473 +551,155 @@ { "data": { "text/html": [ - "
{\n", - " \"acc#acc\": 0.871875,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 279.0\n", - "}\n", - "\n" + "\n" ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] + "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.878125,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 281.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.871875,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 279.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.903125,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 289.0\n", - "}\n", - "\n" - ], "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.903125\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m289.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" + "{'acc#acc': 0.87156, 'total#acc': 872.0, 'correct#acc': 760.0}" ] }, + "execution_count": 14, "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.871875,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 279.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.890625,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 285.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.875,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 280.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.8875,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 284.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.8875,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 284.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.890625,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 285.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" + "output_type": "execute_result" } ], "source": [ - "trainer.run(num_eval_batch_per_dl=10)" + "trainer.evaluator.run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 附:`DistilBertForSequenceClassification`模块结构\n", + "\n", + "```\n", + "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] }, { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " " + ] + }, + { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "{'acc#acc': 0.644495, 'total#acc': 872.0, 'correct#acc': 562.0}" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "trainer.evaluator.run()" ] diff --git a/tutorials/figures/E1-fig-glue-benchmark.png b/tutorials/figures/E1-fig-glue-benchmark.png new file mode 100644 index 00000000..515db700 Binary files /dev/null and b/tutorials/figures/E1-fig-glue-benchmark.png differ diff --git a/tutorials/figures/E2-fig-p-tuning-v2-model.png b/tutorials/figures/E2-fig-p-tuning-v2-model.png new file mode 100644 index 00000000..b5a9c1b8 Binary files /dev/null and b/tutorials/figures/E2-fig-p-tuning-v2-model.png differ