From 19a48c7101d8c95ade89bdf09ad907ee13d7e78b Mon Sep 17 00:00:00 2001 From: lxr-tech <1838593642@qq.com> Date: Fri, 27 May 2022 22:47:11 +0800 Subject: [PATCH 1/2] update example-12 lxr 220527 --- tutorials/fastnlp_tutorial_e1.ipynb | 11 +- tutorials/fastnlp_tutorial_e2.ipynb | 815 +++++++++--------------------------- 2 files changed, 218 insertions(+), 608 deletions(-) diff --git a/tutorials/fastnlp_tutorial_e1.ipynb b/tutorials/fastnlp_tutorial_e1.ipynb index 92a49925..628dd7ae 100644 --- a/tutorials/fastnlp_tutorial_e1.ipynb +++ b/tutorials/fastnlp_tutorial_e1.ipynb @@ -233,7 +233,7 @@ } ], "source": [ - "num_labels = 3 if task.startswith(\"mnli\") else 1 if task==\"stsb\" else 2\n", + "num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n", "\n", "model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n", "\n", @@ -881,6 +881,15 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } } }, "nbformat": 4, diff --git a/tutorials/fastnlp_tutorial_e2.ipynb b/tutorials/fastnlp_tutorial_e2.ipynb index 8e734f01..9185102f 100644 --- a/tutorials/fastnlp_tutorial_e2.ipynb +++ b/tutorials/fastnlp_tutorial_e2.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# E2. 使用 PrefixTuning 完成 SST2 分类" + "# E2. 使用 continuous prompt 完成 SST2 分类" ] }, { @@ -35,10 +35,12 @@ ], "source": [ "import torch\n", - "import torch.nn as nn\n", "from torch.optim import AdamW\n", "from torch.utils.data import DataLoader, Dataset\n", "\n", + "import torch.nn as nn\n", + "from torch.nn.utils.rnn import pad_sequence\n", + "\n", "import transformers\n", "from transformers import AutoTokenizer\n", "from transformers import AutoModelForSequenceClassification\n", @@ -69,180 +71,226 @@ { "cell_type": "code", "execution_count": 3, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.\n", - "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "253d79d7a67e4dc88338448b5bcb3fb9", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/3 [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "from datasets import load_dataset, load_metric\n", - "\n", - "dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n" - ] - } - ], + "outputs": [], "source": [ - "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n", + "class PromptEncoder(nn.Module):\n", + " def __init__(self, template, hidden_size):\n", + " nn.Module.__init__(self)\n", + " self.template = template\n", + " self.hidden_size = hidden_size\n", + " self.cloze_mask = [[1] * self.template[0] + [1] * self.template[1]]\n", + " self.cloze_mask = torch.LongTensor(self.cloze_mask).bool()\n", + "\n", + " self.seq_indices = torch.LongTensor(list(range(len(self.cloze_mask[0]))))\n", + " # embed\n", + " self.embedding = torch.nn.Embedding(len(self.cloze_mask[0]), hidden_size)\n", + " # LSTM\n", + " self.lstm_head = torch.nn.LSTM(input_size=hidden_size,\n", + " hidden_size=hidden_size // 2,\n", + " num_layers=2, dropout=0.0,\n", + " bidirectional=True, batch_first=True)\n", + " # MLP\n", + " self.mlp_head = nn.Sequential(nn.Linear(hidden_size, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_size, hidden_size))\n", + " print(\"init prompt encoder...\")\n", "\n", - "print(tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\"))" + " def forward(self, device):\n", + " input_embeds = self.embedding(self.seq_indices.to(device)).unsqueeze(0)\n", + " output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]).squeeze()\n", + " return output_embeds" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ - "task_to_keys = {\n", - " \"cola\": (\"sentence\", None),\n", - " \"mnli\": (\"premise\", \"hypothesis\"),\n", - " \"mnli-mm\": (\"premise\", \"hypothesis\"),\n", - " \"mrpc\": (\"sentence1\", \"sentence2\"),\n", - " \"qnli\": (\"question\", \"sentence\"),\n", - " \"qqp\": (\"question1\", \"question2\"),\n", - " \"rte\": (\"sentence1\", \"sentence2\"),\n", - " \"sst2\": (\"sentence\", None),\n", - " \"stsb\": (\"sentence1\", \"sentence2\"),\n", - " \"wnli\": (\"sentence1\", \"sentence2\"),\n", - "}\n", + "class ClassModel(nn.Module):\n", + " def __init__(self, num_labels, model_checkpoint, pseudo_token='[PROMPT]', template=(3, 3)):\n", + " nn.Module.__init__(self)\n", + " self.template = template\n", + " self.num_labels = num_labels\n", + " self.spell_length = sum(template)\n", + " self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n", + " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n", + " num_labels=num_labels)\n", + " for param in self.back_bone.parameters():\n", + " param.requires_grad = False\n", + " self.embeddings = self.back_bone.get_input_embeddings()\n", + " \n", + " self.hidden_size = self.embeddings.embedding_dim\n", + " self.tokenizer.add_special_tokens({'additional_special_tokens': [pseudo_token]})\n", + " self.pseudo_token_id = self.tokenizer.get_vocab()[pseudo_token]\n", + " self.pad_token_id = self.tokenizer.pad_token_id\n", + " \n", + " self.prompt_encoder = PromptEncoder(self.template, self.hidden_size)\n", + "\n", + " self.loss_fn = nn.CrossEntropyLoss()\n", + "\n", + " def get_query(self, query):\n", + " device = query.device\n", + " return torch.cat([torch.tensor([self.tokenizer.cls_token_id]).to(device), # [CLS]\n", + " torch.tensor([self.pseudo_token_id] * self.template[0]).to(device), # [PROMPT]\n", + " torch.tensor([self.tokenizer.mask_token_id]).to(device), # [MASK] \n", + " torch.tensor([self.pseudo_token_id] * self.template[1]).to(device), # [PROMPT]\n", + " query, \n", + " torch.tensor([self.tokenizer.sep_token_id]).to(device)], dim=0) # [SEP]\n", + "\n", + " def forward(self, input_ids):\n", + " input_ids = torch.stack([self.get_query(input_ids[i]) for i in range(len(input_ids))])\n", + " attention_mask = input_ids != self.pad_token_id\n", + " \n", + " bz = input_ids.shape[0]\n", + " inputs_embeds = input_ids.clone()\n", + " inputs_embeds[(input_ids == self.pseudo_token_id)] = self.tokenizer.unk_token_id\n", + " inputs_embeds = self.embeddings(inputs_embeds)\n", + "\n", + " blocked_indices = (input_ids == self.pseudo_token_id).nonzero().reshape((bz, self.spell_length, 2))[:, :, 1] # bz\n", + " replace_embeds = self.prompt_encoder(input_ids.device)\n", + " for bidx in range(bz):\n", + " for i in range(self.spell_length):\n", + " inputs_embeds[bidx, blocked_indices[bidx, i], :] = replace_embeds[i, :]\n", + " \n", + " return self.back_bone(inputs_embeds=inputs_embeds, attention_mask=attention_mask)\n", "\n", - "sentence1_key, sentence2_key = task_to_keys[task]" + " def train_step(self, input_ids, attention_mask, labels):\n", + " pred = self(input_ids).logits\n", + " return {\"loss\": self.loss_fn(pred, labels)}\n", + "\n", + " def evaluate_step(self, input_ids, attention_mask, labels):\n", + " pred = self(input_ids).logits\n", + " pred = torch.max(pred, dim=-1)[1]\n", + " return {\"pred\": pred, \"target\": labels}" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_transform.weight']\n", + "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.weight', 'pre_classifier.weight', 'classifier.bias']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + }, + { "name": "stdout", "output_type": "stream", "text": [ - "Sentence: hide new secretions from the parental units \n" + "init prompt encoder...\n" ] } ], "source": [ - "if sentence2_key is None:\n", - " print(f\"Sentence: {dataset['train'][0][sentence1_key]}\")\n", - "else:\n", - " print(f\"Sentence 1: {dataset['train'][0][sentence1_key]}\")\n", - " print(f\"Sentence 2: {dataset['train'][0][sentence2_key]}\")" + "num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n", + "\n", + "model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=5e-4)" ] }, { "cell_type": "code", - "execution_count": 7, - "metadata": {}, + "execution_count": 6, + "metadata": { + "scrolled": false + }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ca1fbe5e8eb059f3.arrow\n", - "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-03661263fbf302f5.arrow\n", - "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-fbe8e7a4e4f18f45.arrow\n" + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f82d2ccee863492582f94552654482f9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "def preprocess_function(examples):\n", - " if sentence2_key is None:\n", - " return tokenizer(examples[sentence1_key], truncation=True)\n", - " return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)\n", - "\n", - "encoded_dataset = dataset.map(preprocess_function, batched=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "class ClassModel(nn.Module):\n", - " def __init__(self, num_labels, model_checkpoint):\n", - " nn.Module.__init__(self)\n", - " self.num_labels = num_labels\n", - " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n", - " num_labels=num_labels)\n", - " self.loss_fn = nn.CrossEntropyLoss()\n", - "\n", - " def forward(self, input_ids, attention_mask):\n", - " return self.back_bone(input_ids, attention_mask)\n", - "\n", - " def train_step(self, input_ids, attention_mask, labels):\n", - " pred = self(input_ids, attention_mask).logits\n", - " return {\"loss\": self.loss_fn(pred, labels)}\n", + "from datasets import load_dataset, load_metric\n", "\n", - " def evaluate_step(self, input_ids, attention_mask, labels):\n", - " pred = self(input_ids, attention_mask).logits\n", - " pred = torch.max(pred, dim=-1)[1]\n", - " return {\"pred\": pred, \"target\": labels}" + "dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight']\n", - "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", - "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", - "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'classifier.bias', 'pre_classifier.bias']\n", - "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "cf324902e7b94ea9be709b979b425c96", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/68 [00:00, ?ba/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "21eb6203ec6f4592b8cb8530a59eda49", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1 [00:00, ?ba/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "05b83c4b1a9f44aea805788e1e52db78", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/2 [00:00, ?ba/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "num_labels = 3 if task.startswith(\"mnli\") else 1 if task==\"stsb\" else 2\n", - "\n", - "model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n", + "def preprocess_function(examples):\n", + " return model.tokenizer(examples['sentence'], truncation=True)\n", "\n", - "optimizers = AdamW(params=model.parameters(), lr=5e-5)" + "encoded_dataset = dataset.map(preprocess_function, batched=True)" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -261,7 +309,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -287,7 +335,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -301,7 +349,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -319,54 +367,15 @@ }, { "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "# help(model.back_bone.forward)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
[21:00:11] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", - "\n" - ], - "text/plain": [ - "\u001b[2;36m[21:00:11]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=22992;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=669026;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" + "\n" ], - "text/plain": [ - "\n" - ] + "text/plain": [] }, "metadata": {}, "output_type": "display_data" @@ -374,12 +383,9 @@ { "data": { "text/html": [ - "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", - "\n" + "\n" ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] + "text/plain": [] }, "metadata": {}, "output_type": "display_data" @@ -387,33 +393,32 @@ { "data": { "text/html": [ - "
{\n", - " \"acc#acc\": 0.871875,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 279.0\n", - "}\n", + "\n", "\n" ], "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" + "\n" ] }, "metadata": {}, "output_type": "display_data" - }, + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ { "data": { "text/html": [ - "\n", - "\n" + "\n" ], - "text/plain": [ - "\n" - ] + "text/plain": [] }, "metadata": {}, "output_type": "display_data" @@ -421,439 +426,26 @@ { "data": { "text/html": [ - "----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", - "\n" + "\n" ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] + "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { - "text/html": [ - "{\n", - " \"acc#acc\": 0.878125,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 281.0\n", - "}\n", - "\n" - ], "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" + "{'acc#acc': 0.565367, 'total#acc': 872.0, 'correct#acc': 493.0}" ] }, + "execution_count": 13, "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "{\n", - " \"acc#acc\": 0.871875,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 279.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "{\n", - " \"acc#acc\": 0.903125,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 289.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.903125\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m289.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "{\n", - " \"acc#acc\": 0.871875,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 279.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "{\n", - " \"acc#acc\": 0.890625,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 285.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "{\n", - " \"acc#acc\": 0.875,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 280.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "{\n", - " \"acc#acc\": 0.8875,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 284.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "{\n", - " \"acc#acc\": 0.8875,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 284.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "{\n", - " \"acc#acc\": 0.890625,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 285.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" + "output_type": "execute_result" } ], "source": [ - "trainer.run(num_eval_batch_per_dl=10)" + "trainer.evaluator.run()" ] }, { @@ -881,6 +473,15 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } } }, "nbformat": 4, From ec76ba8887f3e3df9778ae3db36bc5320ec52f62 Mon Sep 17 00:00:00 2001 From: lxr-tech <1838593642@qq.com> Date: Sat, 28 May 2022 17:06:57 +0800 Subject: [PATCH 2/2] update example-2 lxr 220528 --- tutorials/fastnlp_tutorial_e2.ipynb | 181 +++++++++++------------------------- 1 file changed, 54 insertions(+), 127 deletions(-) diff --git a/tutorials/fastnlp_tutorial_e2.ipynb b/tutorials/fastnlp_tutorial_e2.ipynb index 9185102f..1d7746be 100644 --- a/tutorials/fastnlp_tutorial_e2.ipynb +++ b/tutorials/fastnlp_tutorial_e2.ipynb @@ -39,7 +39,6 @@ "from torch.utils.data import DataLoader, Dataset\n", "\n", "import torch.nn as nn\n", - "from torch.nn.utils.rnn import pad_sequence\n", "\n", "import transformers\n", "from transformers import AutoTokenizer\n", @@ -50,7 +49,6 @@ "\n", "import fastNLP\n", "from fastNLP import Trainer\n", - "from fastNLP.core.utils.utils import dataclass_to_dict\n", "from fastNLP.core.metrics import Accuracy\n", "\n", "print(transformers.__version__)" @@ -74,133 +72,79 @@ "metadata": {}, "outputs": [], "source": [ - "class PromptEncoder(nn.Module):\n", - " def __init__(self, template, hidden_size):\n", - " nn.Module.__init__(self)\n", - " self.template = template\n", - " self.hidden_size = hidden_size\n", - " self.cloze_mask = [[1] * self.template[0] + [1] * self.template[1]]\n", - " self.cloze_mask = torch.LongTensor(self.cloze_mask).bool()\n", - "\n", - " self.seq_indices = torch.LongTensor(list(range(len(self.cloze_mask[0]))))\n", - " # embed\n", - " self.embedding = torch.nn.Embedding(len(self.cloze_mask[0]), hidden_size)\n", - " # LSTM\n", - " self.lstm_head = torch.nn.LSTM(input_size=hidden_size,\n", - " hidden_size=hidden_size // 2,\n", - " num_layers=2, dropout=0.0,\n", - " bidirectional=True, batch_first=True)\n", - " # MLP\n", - " self.mlp_head = nn.Sequential(nn.Linear(hidden_size, hidden_size),\n", - " nn.ReLU(),\n", - " nn.Linear(hidden_size, hidden_size))\n", - " print(\"init prompt encoder...\")\n", - "\n", - " def forward(self, device):\n", - " input_embeds = self.embedding(self.seq_indices.to(device)).unsqueeze(0)\n", - " output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]).squeeze()\n", - " return output_embeds" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ "class ClassModel(nn.Module):\n", - " def __init__(self, num_labels, model_checkpoint, pseudo_token='[PROMPT]', template=(3, 3)):\n", + " def __init__(self, model_checkpoint, num_labels, pre_seq_len):\n", " nn.Module.__init__(self)\n", - " self.template = template\n", " self.num_labels = num_labels\n", - " self.spell_length = sum(template)\n", - " self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n", " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n", " num_labels=num_labels)\n", + " self.embeddings = self.back_bone.get_input_embeddings()\n", + "\n", " for param in self.back_bone.parameters():\n", " param.requires_grad = False\n", - " self.embeddings = self.back_bone.get_input_embeddings()\n", - " \n", - " self.hidden_size = self.embeddings.embedding_dim\n", - " self.tokenizer.add_special_tokens({'additional_special_tokens': [pseudo_token]})\n", - " self.pseudo_token_id = self.tokenizer.get_vocab()[pseudo_token]\n", - " self.pad_token_id = self.tokenizer.pad_token_id\n", " \n", - " self.prompt_encoder = PromptEncoder(self.template, self.hidden_size)\n", - "\n", - " self.loss_fn = nn.CrossEntropyLoss()\n", - "\n", - " def get_query(self, query):\n", - " device = query.device\n", - " return torch.cat([torch.tensor([self.tokenizer.cls_token_id]).to(device), # [CLS]\n", - " torch.tensor([self.pseudo_token_id] * self.template[0]).to(device), # [PROMPT]\n", - " torch.tensor([self.tokenizer.mask_token_id]).to(device), # [MASK] \n", - " torch.tensor([self.pseudo_token_id] * self.template[1]).to(device), # [PROMPT]\n", - " query, \n", - " torch.tensor([self.tokenizer.sep_token_id]).to(device)], dim=0) # [SEP]\n", + " self.pre_seq_len = pre_seq_len\n", + " self.prefix_tokens = torch.arange(self.pre_seq_len).long()\n", + " self.prefix_encoder = nn.Embedding(self.pre_seq_len, self.embeddings.embedding_dim)\n", + " \n", + " def get_prompt(self, batch_size):\n", + " prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.back_bone.device)\n", + " prompts = self.prefix_encoder(prefix_tokens)\n", + " return prompts\n", "\n", - " def forward(self, input_ids):\n", - " input_ids = torch.stack([self.get_query(input_ids[i]) for i in range(len(input_ids))])\n", - " attention_mask = input_ids != self.pad_token_id\n", + " def forward(self, input_ids, attention_mask, labels):\n", " \n", - " bz = input_ids.shape[0]\n", - " inputs_embeds = input_ids.clone()\n", - " inputs_embeds[(input_ids == self.pseudo_token_id)] = self.tokenizer.unk_token_id\n", - " inputs_embeds = self.embeddings(inputs_embeds)\n", - "\n", - " blocked_indices = (input_ids == self.pseudo_token_id).nonzero().reshape((bz, self.spell_length, 2))[:, :, 1] # bz\n", - " replace_embeds = self.prompt_encoder(input_ids.device)\n", - " for bidx in range(bz):\n", - " for i in range(self.spell_length):\n", - " inputs_embeds[bidx, blocked_indices[bidx, i], :] = replace_embeds[i, :]\n", + " batch_size = input_ids.shape[0]\n", + " raw_embedding = self.embeddings(input_ids)\n", " \n", - " return self.back_bone(inputs_embeds=inputs_embeds, attention_mask=attention_mask)\n", + " prompts = self.get_prompt(batch_size=batch_size)\n", + " inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)\n", + " prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.back_bone.device)\n", + " attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)\n", + "\n", + " outputs = self.back_bone(inputs_embeds=inputs_embeds, \n", + " attention_mask=attention_mask, labels=labels)\n", + " return outputs\n", "\n", " def train_step(self, input_ids, attention_mask, labels):\n", - " pred = self(input_ids).logits\n", - " return {\"loss\": self.loss_fn(pred, labels)}\n", + " return {\"loss\": self(input_ids, attention_mask, labels).loss}\n", "\n", " def evaluate_step(self, input_ids, attention_mask, labels):\n", - " pred = self(input_ids).logits\n", + " pred = self(input_ids, attention_mask, labels).logits\n", " pred = torch.max(pred, dim=-1)[1]\n", " return {\"pred\": pred, \"target\": labels}" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_transform.weight']\n", + "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight']\n", "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", - "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.weight', 'pre_classifier.weight', 'classifier.bias']\n", + "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "init prompt encoder...\n" - ] } ], "source": [ "num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n", "\n", - "model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n", + "model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint, pre_seq_len=16)\n", + "\n", + "# Generally, simple classification tasks prefer shorter prompts (less than 20)\n", "\n", - "optimizers = AdamW(params=model.parameters(), lr=5e-4)" + "optimizers = AdamW(params=model.parameters(), lr=5e-3)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": { "scrolled": false }, @@ -209,13 +153,14 @@ "name": "stderr", "output_type": "stream", "text": [ + "Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.\n", "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "f82d2ccee863492582f94552654482f9", + "model_id": "1b73650d43f245ac8a5501dc91c6fe8c", "version_major": 2, "version_minor": 0 }, @@ -230,46 +175,28 @@ "source": [ "from datasets import load_dataset, load_metric\n", "\n", - "dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)" + "dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "cf324902e7b94ea9be709b979b425c96", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/68 [00:00, ?ba/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "21eb6203ec6f4592b8cb8530a59eda49", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1 [00:00, ?ba/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-294e481a713c5754.arrow\n", + "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ed9d9258aaf0fb54.arrow\n" + ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "05b83c4b1a9f44aea805788e1e52db78", + "model_id": "0be84915c90f460896b8e67299e09df4", "version_major": 2, "version_minor": 0 }, @@ -283,14 +210,14 @@ ], "source": [ "def preprocess_function(examples):\n", - " return model.tokenizer(examples['sentence'], truncation=True)\n", + " return tokenizer(examples['sentence'], truncation=True)\n", "\n", "encoded_dataset = dataset.map(preprocess_function, batched=True)" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -309,7 +236,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -335,7 +262,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -349,7 +276,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -367,7 +294,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -410,7 +337,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -436,10 +363,10 @@ { "data": { "text/plain": [ - "{'acc#acc': 0.565367, 'total#acc': 872.0, 'correct#acc': 493.0}" + "{'acc#acc': 0.644495, 'total#acc': 872.0, 'correct#acc': 562.0}" ] }, - "execution_count": 13, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" }