Browse Source

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
e402893017
2 changed files with 144 additions and 607 deletions
  1. +10
    -1
      tutorials/fastnlp_tutorial_e1.ipynb
  2. +134
    -606
      tutorials/fastnlp_tutorial_e2.ipynb

+ 10
- 1
tutorials/fastnlp_tutorial_e1.ipynb View File

@@ -233,7 +233,7 @@
}
],
"source": [
"num_labels = 3 if task.startswith(\"mnli\") else 1 if task==\"stsb\" else 2\n",
"num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n",
"\n",
"model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n",
"\n",
@@ -881,6 +881,15 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,


+ 134
- 606
tutorials/fastnlp_tutorial_e2.ipynb View File

@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# E2. 使用 PrefixTuning 完成 SST2 分类"
"# E2. 使用 continuous prompt 完成 SST2 分类"
]
},
{
@@ -35,10 +35,11 @@
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.optim import AdamW\n",
"from torch.utils.data import DataLoader, Dataset\n",
"\n",
"import torch.nn as nn\n",
"\n",
"import transformers\n",
"from transformers import AutoTokenizer\n",
"from transformers import AutoModelForSequenceClassification\n",
@@ -48,7 +49,6 @@
"\n",
"import fastNLP\n",
"from fastNLP import Trainer\n",
"from fastNLP.core.utils.utils import dataclass_to_dict\n",
"from fastNLP.core.metrics import Accuracy\n",
"\n",
"print(transformers.__version__)"
@@ -69,6 +69,82 @@
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class ClassModel(nn.Module):\n",
" def __init__(self, model_checkpoint, num_labels, pre_seq_len):\n",
" nn.Module.__init__(self)\n",
" self.num_labels = num_labels\n",
" self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n",
" num_labels=num_labels)\n",
" self.embeddings = self.back_bone.get_input_embeddings()\n",
"\n",
" for param in self.back_bone.parameters():\n",
" param.requires_grad = False\n",
" \n",
" self.pre_seq_len = pre_seq_len\n",
" self.prefix_tokens = torch.arange(self.pre_seq_len).long()\n",
" self.prefix_encoder = nn.Embedding(self.pre_seq_len, self.embeddings.embedding_dim)\n",
" \n",
" def get_prompt(self, batch_size):\n",
" prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.back_bone.device)\n",
" prompts = self.prefix_encoder(prefix_tokens)\n",
" return prompts\n",
"\n",
" def forward(self, input_ids, attention_mask, labels):\n",
" \n",
" batch_size = input_ids.shape[0]\n",
" raw_embedding = self.embeddings(input_ids)\n",
" \n",
" prompts = self.get_prompt(batch_size=batch_size)\n",
" inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)\n",
" prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.back_bone.device)\n",
" attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)\n",
"\n",
" outputs = self.back_bone(inputs_embeds=inputs_embeds, \n",
" attention_mask=attention_mask, labels=labels)\n",
" return outputs\n",
"\n",
" def train_step(self, input_ids, attention_mask, labels):\n",
" return {\"loss\": self(input_ids, attention_mask, labels).loss}\n",
"\n",
" def evaluate_step(self, input_ids, attention_mask, labels):\n",
" pred = self(input_ids, attention_mask, labels).logits\n",
" pred = torch.max(pred, dim=-1)[1]\n",
" return {\"pred\": pred, \"target\": labels}"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight']\n",
"- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n",
"\n",
"model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint, pre_seq_len=16)\n",
"\n",
"# Generally, simple classification tasks prefer shorter prompts (less than 20)\n",
"\n",
"optimizers = AdamW(params=model.parameters(), lr=5e-3)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"scrolled": false
},
@@ -84,7 +160,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "253d79d7a67e4dc88338448b5bcb3fb9",
"model_id": "1b73650d43f245ac8a5501dc91c6fe8c",
"version_major": 2,
"version_minor": 0
},
@@ -99,48 +175,9 @@
"source": [
"from datasets import load_dataset, load_metric\n",
"\n",
"dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n"
]
}
],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n",
"dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)\n",
"\n",
"print(tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\"))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"task_to_keys = {\n",
" \"cola\": (\"sentence\", None),\n",
" \"mnli\": (\"premise\", \"hypothesis\"),\n",
" \"mnli-mm\": (\"premise\", \"hypothesis\"),\n",
" \"mrpc\": (\"sentence1\", \"sentence2\"),\n",
" \"qnli\": (\"question\", \"sentence\"),\n",
" \"qqp\": (\"question1\", \"question2\"),\n",
" \"rte\": (\"sentence1\", \"sentence2\"),\n",
" \"sst2\": (\"sentence\", None),\n",
" \"stsb\": (\"sentence1\", \"sentence2\"),\n",
" \"wnli\": (\"sentence1\", \"sentence2\"),\n",
"}\n",
"\n",
"sentence1_key, sentence2_key = task_to_keys[task]"
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)"
]
},
{
@@ -149,100 +186,38 @@
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sentence: hide new secretions from the parental units \n"
]
}
],
"source": [
"if sentence2_key is None:\n",
" print(f\"Sentence: {dataset['train'][0][sentence1_key]}\")\n",
"else:\n",
" print(f\"Sentence 1: {dataset['train'][0][sentence1_key]}\")\n",
" print(f\"Sentence 2: {dataset['train'][0][sentence2_key]}\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ca1fbe5e8eb059f3.arrow\n",
"Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-03661263fbf302f5.arrow\n",
"Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-fbe8e7a4e4f18f45.arrow\n"
"Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-294e481a713c5754.arrow\n",
"Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ed9d9258aaf0fb54.arrow\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0be84915c90f460896b8e67299e09df4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/2 [00:00<?, ?ba/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def preprocess_function(examples):\n",
" if sentence2_key is None:\n",
" return tokenizer(examples[sentence1_key], truncation=True)\n",
" return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)\n",
" return tokenizer(examples['sentence'], truncation=True)\n",
"\n",
"encoded_dataset = dataset.map(preprocess_function, batched=True)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class ClassModel(nn.Module):\n",
" def __init__(self, num_labels, model_checkpoint):\n",
" nn.Module.__init__(self)\n",
" self.num_labels = num_labels\n",
" self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n",
" num_labels=num_labels)\n",
" self.loss_fn = nn.CrossEntropyLoss()\n",
"\n",
" def forward(self, input_ids, attention_mask):\n",
" return self.back_bone(input_ids, attention_mask)\n",
"\n",
" def train_step(self, input_ids, attention_mask, labels):\n",
" pred = self(input_ids, attention_mask).logits\n",
" return {\"loss\": self.loss_fn(pred, labels)}\n",
"\n",
" def evaluate_step(self, input_ids, attention_mask, labels):\n",
" pred = self(input_ids, attention_mask).logits\n",
" pred = torch.max(pred, dim=-1)[1]\n",
" return {\"pred\": pred, \"target\": labels}"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight']\n",
"- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'classifier.bias', 'pre_classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"num_labels = 3 if task.startswith(\"mnli\") else 1 if task==\"stsb\" else 2\n",
"\n",
"model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n",
"\n",
"optimizers = AdamW(params=model.parameters(), lr=5e-5)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -261,7 +236,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -287,7 +262,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -301,7 +276,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
@@ -319,443 +294,15 @@
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# help(model.back_bone.forward)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[21:00:11] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches. <a href=\"file://../fastNLP/core/controllers/trainer.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/controllers/trainer.py#592\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">592</span></a>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[2;36m[21:00:11]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=22992;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=669026;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.871875</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">279.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.878125</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">281.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.871875</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">279.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.903125</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">289.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.903125\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m289.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.871875</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">279.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.890625</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">285.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.875</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">280.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8875</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">284.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">9</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
@@ -763,20 +310,9 @@
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8875</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">284.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
@@ -793,37 +329,23 @@
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
}
],
"source": [
"trainer.run(num_eval_batch_per_dl=10)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.890625</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">285.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
@@ -840,20 +362,17 @@
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
"{'acc#acc': 0.644495, 'total#acc': 872.0, 'correct#acc': 562.0}"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "display_data"
"output_type": "execute_result"
}
],
"source": [
"trainer.run(num_eval_batch_per_dl=10)"
"trainer.evaluator.run()"
]
},
{
@@ -881,6 +400,15 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,


Loading…
Cancel
Save